diff --git a/.github/workflows/build_cython_extensions.yml b/.github/workflows/build_cython_extensions.yml new file mode 100644 index 000000000..9853df048 --- /dev/null +++ b/.github/workflows/build_cython_extensions.yml @@ -0,0 +1,111 @@ +name: Build Cython extensions + +permissions: + contents: write + +on: + push: + paths: + - "cellacdc/**/*.pyx" + - "precompile_functions.py" + workflow_dispatch: + +jobs: + build: + runs-on: ${{ matrix.os }} + defaults: + run: + working-directory: ${{ github.workspace }} + strategy: + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + python-version: ["3.10", "3.11", "3.12", "3.13"] + + steps: + - uses: actions/checkout@v6 + with: + fetch-depth: 0 + + - uses: actions/setup-python@v6 + with: + python-version: ${{ matrix.python-version }} + + - name: Build extension + run: | + python -m pip install --upgrade "setuptools>=77" cython numpy + # Remove existing binaries so only the freshly built one remains for upload + python -c " + import os, glob + for f in glob.glob('cellacdc/precompiled/*.so') + glob.glob('cellacdc/precompiled/*.pyd'): + os.remove(f) + print(f'Removed existing: {f}') + " + python "${{ github.workspace }}/precompile_functions.py" build_ext --inplace --build-temp "${{ github.workspace }}/build/temp" + echo "Built binary:" + ls cellacdc/precompiled/ + + - uses: actions/upload-artifact@v6 + with: + name: precompiled-${{ matrix.os }}-py${{ matrix.python-version }} + path: | + cellacdc/precompiled/*.so + cellacdc/precompiled/*.pyd + + commit: + needs: build + runs-on: ubuntu-latest + permissions: + contents: write + defaults: + run: + working-directory: ${{ github.workspace }} + steps: + - uses: actions/checkout@v6 + with: + ref: ${{ github.ref_name }} + fetch-depth: 0 + + - uses: actions/download-artifact@v6 + with: + path: artifacts/ + + - name: Flatten artifacts into precompiled directory + run: | + mkdir -p cellacdc/precompiled + find artifacts -type f \( -name "*.pyd" -o -name "*.so" \) -exec cp {} cellacdc/precompiled/ \; + echo "Files copied:" + ls -la cellacdc/precompiled/ + echo "" + echo "Total files:" + find cellacdc/precompiled/ -type f | grep -E '\.(pyd|so)$' | wc -l + + - name: Validate artifact count + run: | + EXPECTED_COUNT=12 + FILE_COUNT=$(find cellacdc/precompiled/ -type f \( -name "*.pyd" -o -name "*.so" \) | wc -l) + echo "Expected files: $EXPECTED_COUNT" + echo "Found files: $FILE_COUNT" + if [ "$FILE_COUNT" -lt "$EXPECTED_COUNT" ]; then + echo "Missing binaries: expected at least $EXPECTED_COUNT, found $FILE_COUNT" + exit 1 + fi + + - name: Commit precompiled binaries + run: | + git config user.name "github-actions[bot]" + git config user.email "github-actions[bot]@users.noreply.github.com" + find cellacdc/precompiled/ -type f \( -name "*.pyd" -o -name "*.so" \) -print0 | xargs -0 git add -f + git add cellacdc/precompiled/__init__.py + FILE_COUNT=$(find cellacdc/precompiled/ -type f \( -name "*.pyd" -o -name "*.so" \) | wc -l) + echo "Files to commit: $FILE_COUNT" + if [ "$FILE_COUNT" -eq 0 ]; then + echo "No precompiled files found, skipping commit" + exit 0 + fi + if git diff --cached --quiet; then + echo "No changes to commit" + exit 0 + fi + git commit -m "ci: update precompiled Cython extensions [skip ci]" + git pull --rebase origin ${{ github.ref_name }} + git push origin HEAD:${{ github.ref_name }} \ No newline at end of file diff --git a/.gitignore b/.gitignore index 6d90763c1..3b45e534f 100755 --- a/.gitignore +++ b/.gitignore @@ -26,6 +26,13 @@ requirements_new.txt **/weights_location_path.txt **/_test +# cython generated files +*.pyd +*.so +!cellacdc/precompiled/ +!cellacdc/precompiled/*.pyd +!cellacdc/precompiled/*.so + # Test output plots tests/_plots/ @@ -36,7 +43,6 @@ setup.cfg # Starting from pip 21.3 setup.py is not needed anymore # and we rely only on setup.cfg for env -setup.py environment.yml # requirements.txt conda_env_list_commands.txt diff --git a/cellacdc/__init__.py b/cellacdc/__init__.py index c4fcaf553..15355abdb 100755 --- a/cellacdc/__init__.py +++ b/cellacdc/__init__.py @@ -806,4 +806,17 @@ def inner_function(self, *args, **kwargs): single_pos_index_cols = ( 'experiment_folderpath', 'Position_n' -) \ No newline at end of file +) + +precompiled_import_success = False +try: + from cellacdc.precompiled.precompiled_functions import ( + find_all_objects_2D, + find_all_objects_3D, + calc_IoA_matrix_2D, + calc_IoA_matrix_3D, + most_common_projection_3D, + ) + precompiled_import_success = True +except Exception: + pass \ No newline at end of file diff --git a/cellacdc/_warnings.py b/cellacdc/_warnings.py index e313b3472..e5e2042f3 100644 --- a/cellacdc/_warnings.py +++ b/cellacdc/_warnings.py @@ -32,6 +32,34 @@ def warnTooManyItems(mainWin, numItems, qparent): ) return msg.cancel, msg.clickedButton==switchToLowResButton +def warnTooManyNewItems(mainWin, numItems, qparent): + from . import widgets + mainWin.logger.info( + '[WARNING]: asking user what to do with too many objects...' + ) + msg = widgets.myMessageBox(wrapText=False) + txt = html_utils.paragraph(f""" + WARNING: The resulting segmentation mask has {numItems} objects.

+ Creating high resolution text annotations + for these many objects could take a long time.

+ We recommend deactivating text annotations or switching to low resolution annotations.

+ You can still try to switch to activate them or switch to high resolution later.

+ What do you want to do? + """) + + _, switchToLowResButton, deactivateAnnotButton = msg.warning( + qparent, 'Too many objects', txt, + buttonsTexts=( + 'Cancel', + widgets.reloadPushButton(' Switch to low resolution '), + widgets.noPushButton(' Deactivate text annotations ') + ) + ) + switchToLowRes = msg.clickedButton==switchToLowResButton + deactivateAnnot = msg.clickedButton==deactivateAnnotButton + + return msg.cancel, switchToLowRes, deactivateAnnot + def warnRestartCellACDCcolorModeToggled( scheme, app_name='Cell-ACDC', parent=None ): diff --git a/cellacdc/annotate.py b/cellacdc/annotate.py index ec09b3610..cbe9f2c97 100644 --- a/cellacdc/annotate.py +++ b/cellacdc/annotate.py @@ -5,7 +5,7 @@ import pandas as pd from . import GUI_INSTALLED -from . import cellacdc_path, printl, ignore_exception +from . import cellacdc_path, printl, ignore_exception, debugutils if GUI_INSTALLED: from PIL import Image, ImageFont, ImageDraw @@ -169,9 +169,10 @@ def appendData(self, data, text): self.annotData.append(data) self.texts.append(text) - def highlightObject(self, obj): + def highlightObject(self, obj, rp=None, getObjCentroidFunc=None): self.highlighterItem.texts = self.texts - self.highlighterItem.highlightObject(obj) + self.highlighterItem.highlightObject( + obj, rp=rp, getObjCentroidFunc=getObjCentroidFunc) def grayOutAnnotations(self, IDsToSkip=None): self.setOpacity(0.3) @@ -207,6 +208,45 @@ def __init__(self, *args, anchor=(0.5, 0.5), **kargs): self.texts = [] self.annotData = [] self._anchor = anchor + + def _rebuildSizes(self, bold=False): + if bold: + self.sizesBold = plot.get_symbol_sizes( + self.scalesBold, self.symbolsBold, self.fontSize + ) + self._maxScaleBold = max(self.scalesBold.values(), default=None) + else: + self.sizesRegular = plot.get_symbol_sizes( + self.scalesRegular, self.symbolsRegular, self.fontSize + ) + self._maxScaleRegular = max(self.scalesRegular.values(), default=None) + + def _updateSizesForTexts(self, texts, bold=False): + if not texts: + return + + if bold: + scales = self.scalesBold + sizes_attr = 'sizesBold' + max_scale_attr = '_maxScaleBold' + else: + scales = self.scalesRegular + sizes_attr = 'sizesRegular' + max_scale_attr = '_maxScaleRegular' + + current_max_scale = getattr(self, max_scale_attr, None) + if current_max_scale is None: + self._rebuildSizes(bold=bold) + return + + added_max_scale = max(scales[text] for text in texts) + if added_max_scale > current_max_scale: + self._rebuildSizes(bold=bold) + return + + sizes = getattr(self, sizes_attr) + for text in texts: + sizes[text] = int(np.round(self.fontSize*current_max_scale/scales[text])) def clearData(self): self.setData([], []) @@ -254,15 +294,28 @@ def initSymbols(self, allIDs, onlyIDs=False): self.createSymbols(annotTexts) def addSymbols(self, annotTexts, includeBold=True): - for text in annotTexts: - if includeBold: - self.symbolsBold[text] = self.getObjTextAnnotSymbol( - text, bold=True, initSizes=False + if includeBold: + missing_bold = [ + text for text in annotTexts if text not in self.symbolsBold + ] + if missing_bold: + symbolsBold, scalesBold = plot.texts_to_pg_scatter_symbols( + missing_bold, font=self.fontBold, return_scales=True ) - self.symbolsRegular[text] = self.getObjTextAnnotSymbol( - text, bold=True, initSizes=False + self.symbolsBold.update(symbolsBold) + self.scalesBold.update(scalesBold) + self._updateSizesForTexts(missing_bold, bold=True) + + missing_regular = [ + text for text in annotTexts if text not in self.symbolsRegular + ] + if missing_regular: + symbolsRegular, scalesRegular = plot.texts_to_pg_scatter_symbols( + missing_regular, font=self.fontRegular, return_scales=True ) - self.initSizes(includeBold=includeBold) + self.symbolsRegular.update(symbolsRegular) + self.scalesRegular.update(scalesRegular) + self._updateSizesForTexts(missing_regular, bold=False) def createSymbols(self, annotTexts, includeBold=True): if includeBold: @@ -281,12 +334,8 @@ def initSizes(self, includeBold=True): includeBold = False if includeBold: - self.sizesBold = plot.get_symbol_sizes( - self.scalesBold, self.symbolsBold, self.fontSize - ) - self.sizesRegular = plot.get_symbol_sizes( - self.scalesRegular, self.symbolsRegular, self.fontSize - ) + self._rebuildSizes(bold=True) + self._rebuildSizes(bold=False) def setColors(self, colors): self._colors = colors.copy() @@ -325,7 +374,7 @@ def getObjTextAnnotSymbol(self, text, bold=False, initSizes=True): symbols[text] = symbol scales[text] = scale if initSizes: - self.initSizes() + self._updateSizesForTexts([text], bold=bold) return symbol def grayOutAnnotations(self, IDsToSkip=None): @@ -346,7 +395,7 @@ def grayOutAnnotations(self, IDsToSkip=None): self.setBrush(brushes) self.setPen(pens) - def highlightObject(self, obj): + def highlightObject(self, obj, rp=None, getObjCentroidFunc=None): ID = obj.label objIdx = None for idx, objData in enumerate(self.data): @@ -357,7 +406,14 @@ def highlightObject(self, obj): objOpts = { 'text': str(ID), 'bold': True, 'color_name': 'new_object' } - yc, xc = obj.centroid[-2:] + if rp is not None: + centroid = rp.get_centroid(obj.label) + else: + centroid = obj.centroid + if getObjCentroidFunc is not None: + yc, xc = getObjCentroidFunc(centroid) + else: + yc, xc = centroid[-2:] pos = (int(xc), int(yc)) self.addObjAnnot(pos, draw=True, **objOpts) return @@ -504,13 +560,20 @@ def removeFromPlotItem(self, ax): if hasattr(self.item, 'highlighterItem'): ax.removeItem(self.item.highlighterItem) - def addObjAnnotation(self, obj, color_name, text, bold): + def addObjAnnotation(self, obj, color_name, text, bold, rp=None, getObjCentroidFunc=None): objOpts = { 'text': text, 'bold': bold, 'color_name': color_name, } - yc, xc = obj.centroid[-2:] + if rp is not None: + centroid = rp.get_centroid(obj.label) + else: + centroid = obj.centroid + if getObjCentroidFunc is not None: + yc, xc = getObjCentroidFunc(centroid) + else: + yc, xc = centroid[-2:] pos = (int(xc), int(yc)) objData = self.item.addObjAnnot(pos, draw=True, **objOpts) self.item.appendData(objData, objOpts['text']) @@ -563,7 +626,8 @@ def setAnnotations(self, **kwargs): isGenNumTreeAnnotation, posData.frame_i ) - yc, xc = getObjCentroidFunc(obj.centroid) + centroid = posData.rp.get_centroid(obj.label) + yc, xc = getObjCentroidFunc(centroid) try: rp_zslice = posData.zSlicesRp[currentZ] obj_2d = rp_zslice[obj.label] @@ -598,7 +662,8 @@ def setAnnotations(self, **kwargs): 'color_name': 'tracked_lost_object', 'bold': False, } - yc, xc = obj.centroid[-2:] + centroid = prev_rp.get_centroid(obj.label) + yc, xc = getObjCentroidFunc(centroid) pos = (int(xc), int(yc)) objData = self.item.addObjAnnot(pos, draw=False, **objOpts) self.item.appendData(objData, objOpts['text']) @@ -624,7 +689,8 @@ def setAnnotations(self, **kwargs): 'color_name': 'lost_object', 'bold': False, } - yc, xc = getObjCentroidFunc(obj.centroid) + centroid = prev_rp.get_centroid(obj.label) + yc, xc = getObjCentroidFunc(centroid) try: pos = (int(xc), int(yc)) except Exception as err: @@ -638,8 +704,8 @@ def setAnnotations(self, **kwargs): self.item.draw() - def highlightObject(self, obj): - self.item.highlightObject(obj) + def highlightObject(self, obj, rp=None, getObjCentroidFunc=None): + self.item.highlightObject(obj, rp=rp, getObjCentroidFunc=getObjCentroidFunc) def removeHighlightObject(self, obj): self.item.removeHighlightObject(obj) diff --git a/cellacdc/apps.py b/cellacdc/apps.py index 5bc4553c0..dc4cd369b 100755 --- a/cellacdc/apps.py +++ b/cellacdc/apps.py @@ -94,18 +94,13 @@ from . import io from . import cca_functions from . import path +from . import fonts POSITIVE_FLOAT_REGEX = float_regex(allow_negative=False) TREEWIDGET_STYLESHEET = _palettes.TreeWidgetStyleSheet() LISTWIDGET_STYLESHEET = _palettes.ListWidgetStyleSheet() BACKGROUND_RGBA = _palettes.get_disabled_colors()['Button'] -font = QFont() -font.setPixelSize(12) -italicFont = QFont() -italicFont.setPixelSize(12) -italicFont.setItalic(True) - class ArgWidget: def __init__(self, name, type, widget, defaultVal, valueSetter, valueGetter, changeSig=None): self.name = name @@ -838,7 +833,7 @@ def __init__( self.setLayout(mainLayout) - self.setFont(font) + self.setFont(fonts.font) def addCentroidsSection(self, row, layout, **kwargs): sectionWidgets = [] @@ -1441,7 +1436,7 @@ def __init__(self, parent=None): self.setLayout(mainLayout) - self.setFont(font) + self.setFont(fonts.font) def restoreState(self, state): self.appearanceGroupbox.restoreState(state) @@ -1588,7 +1583,7 @@ def __init__( layout.addLayout(buttonsLayout) self.setLayout(layout) - self.setFont(font) + self.setFont(fonts.font) if defaultEntry: self.updateFilename(defaultEntry) @@ -1856,7 +1851,7 @@ def __init__(self, basename='', parent=None): mainLayout.addLayout(buttonsLayout) self.setLayout(mainLayout) - self.setFont(font) + self.setFont(fonts.font) def createThirdSegmToggled(self, checked): self.appendTextWidget.setDisabled(not checked) @@ -3026,9 +3021,7 @@ def __init__( self.imageViewer = None super().__init__(parent) self.setWindowTitle(title) - font = QFont() - font.setPixelSize(12) - self.setFont(font) + self.setFont(fonts.font) mainLayout = QVBoxLayout() entriesLayout = QGridLayout() @@ -3922,7 +3915,7 @@ def __init__(self, parent=None): layout.addLayout(buttonsLayout) layout.addStretch(1) self.setLayout(layout) - self.setFont(font) + self.setFont(fonts.font) def showInfo(self): msg = widgets.myMessageBox(wrapText=False) @@ -4114,7 +4107,7 @@ def __init__( layout.addLayout(buttonsLayout) layout.addStretch(1) self.setLayout(layout) - self.setFont(font) + self.setFont(fonts.font) def selectFeatures(self): features = measurements.get_btrack_features() @@ -4340,7 +4333,7 @@ def __init__(self, posData=None, parent=None): layout.addLayout(buttonsLayout) layout.addStretch(1) self.setLayout(layout) - self.setFont(font) + self.setFont(fonts.font) def methodChanged(self, method): if method == 'mothermachine': @@ -4565,7 +4558,7 @@ def __init__( self.loop = None self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) - self.setFont(font) + self.setFont(fonts.font) def ok_cb(self, checked=False): self.cancel = False @@ -4741,7 +4734,7 @@ def __init__(self, fileName, folderPath, readPatternFunc=None, parent=None): self.setLayout(mainLayout) - self.setFont(font) + self.setFont(fonts.font) def segmFolderpathSelected(self, path): self.segmFolderPathEntry.setText(path) @@ -4941,7 +4934,7 @@ def __init__(self, parent=None, isSegm3D=True): cancelButton.clicked.connect(self.close) self.setLayout(layout) - self.setFont(font) + self.setFont(fonts.font) self.configPars = self.loadLastSelection() @@ -5102,7 +5095,7 @@ def __init__(self, df: pd.DataFrame, parent=None): ) self.setLayout(self.mainLayout) - self.setFont(font) + self.setFont(fonts.font) def saveSelection(self): saved_selections = io.get_saved_moth_bud_tot_selections() @@ -5507,7 +5500,7 @@ def __init__(self, df, parent=None): self.mainLayout.addLayout(buttonsLayout) self.setLayout(self.mainLayout) - self.setFont(font) + self.setFont(fonts.font) def ok_cb(self): self.cancel = False @@ -5577,50 +5570,66 @@ def ok_cb(self): self.model_name = self.listBox.currentItem().text() self.close() - class QDialogSelectModel(QDialog): def __init__( - self, parent=None, addSkipSegmButton=False, customFirst='' + self, parent=None, addSkipSegmButton=False, customFirst='', + allowMultiSelection=False, lastSelection=None, + addSelectLastSelectionButton=False, + addSelectLastRecipeButton=False, + custom_title=None, + info_label='', ): self.cancel = True + self.loadLastRecipe = False super().__init__(parent) self.setWindowTitle('Select model') + self.info_label = info_label - mainLayout = QVBoxLayout() - topLayout = QVBoxLayout() - bottomLayout = QHBoxLayout() + self.allowMultiSelection = allowMultiSelection + self.lastSelection = [] + for m in (lastSelection or []): + if not isinstance(m, str): + continue + if m == 'thresholding': + m = 'Automatic thresholding' + self.lastSelection.append(m) + mainLayout = QVBoxLayout() self.mainLayout = mainLayout + title = custom_title or 'Select model to use for segmentation: ' + + titleContainer = QWidget(self) + titleLayout = QGridLayout(titleContainer) + titleLayout.setContentsMargins(0, 0, 0, 0) + titleLayout.setSpacing(0) label = QLabel(html_utils.paragraph( - 'Select model to use for segmentation: ' + title )) - # padding: top, left, bottom, right label.setStyleSheet("padding:0px 0px 3px 0px;") - topLayout.addWidget(label, alignment=Qt.AlignCenter) - - listBox = widgets.listWidget() - models = myutils.get_list_of_models() + titleLayout.addWidget(label, 0, 0, Qt.AlignCenter) + if info_label: + moreInfoButton = widgets.infoPushButton() + moreInfoButton.clicked.connect(self.showInfoLabel) + moreInfoButton.setSizePolicy( + QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed + ) + titleLayout.addWidget(moreInfoButton, 0, 0, Qt.AlignTop | Qt.AlignRight) + mainLayout.addWidget(titleContainer) - if customFirst: - try: - idx = models.index(customFirst) - models.insert(0, models.pop(idx)) - except ValueError: - print(f'Warning: {customFirst} not found in models list.') - pass + self.modelSelector = widgets.ModelSelectionWidget( + parent=self, + customFirst=customFirst, + allowMultiSelection=allowMultiSelection, + ) + # Convenience aliases kept for backward compatibility + self.listBox = self.modelSelector.listBox + mainLayout.addWidget(self.modelSelector) - listBox.setFont(font) - listBox.addItems(models) - addCustomModelItem = QListWidgetItem('Add custom model...') - addCustomModelItem.setFont(italicFont) - listBox.addItem(addCustomModelItem) - listBox.setSelectionMode(QAbstractItemView.SelectionMode.SingleSelection) - listBox.setCurrentRow(0) - self.listBox = listBox - listBox.itemDoubleClicked.connect(self.ok_cb) - topLayout.addWidget(listBox) + if not allowMultiSelection: + self.listBox.itemDoubleClicked.connect(self.ok_cb) + bottomLayout = QHBoxLayout() cancelButton = widgets.cancelPushButton('Cancel') okButton = widgets.okPushButton(' Ok ') okButton.setShortcut(Qt.Key_Enter) @@ -5632,53 +5641,178 @@ def __init__( skipSegmButton = widgets.SkipPushButton('Skip segmentation') bottomLayout.addWidget(skipSegmButton) skipSegmButton.clicked.connect(self.skipSegm) + if addSelectLastSelectionButton and allowMultiSelection: + selectLastSelButton = widgets.reloadPushButton('Load last selection...') + selectLastSelButton.clicked.connect(self.selectLastSelection) + selectLastSelButton.setEnabled(bool(self.lastSelection)) + bottomLayout.addWidget(selectLastSelButton) + if addSelectLastRecipeButton and allowMultiSelection: + selectLastRecipeButton = widgets.reloadPushButton('Load last recipe...') + selectLastRecipeButton.clicked.connect(self.selectLastRecipe) + selectLastRecipeButton.setEnabled(bool(self.lastSelection)) + bottomLayout.addWidget(selectLastRecipeButton) + if allowMultiSelection: + addCustomModelButton = widgets.addPushButton('Add custom model...') + addCustomModelButton.clicked.connect(self.addCustomModel) + bottomLayout.addWidget(addCustomModelButton) bottomLayout.addWidget(okButton) bottomLayout.setContentsMargins(0, 10, 0, 0) - mainLayout.addLayout(topLayout) mainLayout.addLayout(bottomLayout) self.setLayout(mainLayout) - # Connect events okButton.clicked.connect(self.ok_cb) cancelButton.clicked.connect(self.cancel_cb) - self.setStyleSheet(LISTWIDGET_STYLESHEET) - + + @property + def selectionSequence(self): + return self.modelSelector.selectionSequence + + @property + def modelItemsMap(self): + return self.modelSelector.modelItemsMap + def skipSegm(self): self.cancel = False self.selectedModel = 'skip_segmentation' self.close() - + + def selectLastSelection(self): + if not self.lastSelection: + return + self.modelSelector.setSelectionFromList(self.lastSelection) + + def selectLastRecipe(self): + if not self.lastSelection: + return + self.selectLastSelection() + self.cancel = False + self.loadLastRecipe = True + self.selectedModel = self.lastSelection.copy() + self.close() + + def _runAddCustomModelWorkflow(self): + modelFilePath = addCustomModelMessages(self) + if modelFilePath is None: + return None + + myutils.store_custom_model_path(modelFilePath) + modelName = os.path.basename(os.path.dirname(modelFilePath)) + self.modelSelector.registerCustomModel(modelName) + return modelName + + def addCustomModel(self): + modelName = self._runAddCustomModelWorkflow() + if modelName is None: + return + + if self.allowMultiSelection: + self.modelSelector.addModelSelection(modelName) + else: + item = QListWidgetItem(modelName) + self.listBox.addItem(item) + self.listBox.setCurrentItem(item) + + def showInfoLabel(self): + if not self.info_label: + return + msg = widgets.myMessageBox(showCentered=False, wrapText=False) + txt = html_utils.paragraph(self.info_label) + msg.information(self, 'More info', txt) + def keyPressEvent(self, event: QKeyEvent) -> None: if event.key() == Qt.Key_Escape: event.ignore() return - super().keyPressEvent(event) - def ok_cb(self, event): + def askSelectedModelsOrder(self, selected_models): + dialog = QDialog(self) + dialog.setWindowTitle('Order selected models') + dialog.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + + layout = QVBoxLayout(dialog) + infoTxt = html_utils.paragraph( + 'Drag and drop to change the order of the selected models.
' + 'The top model will run first.' + ) + layout.addWidget(QLabel(infoTxt)) + + modelOrderView = widgets.ReorderableListView( + selected_models, parent=dialog, isSingleSelection=True + ) + layout.addWidget(modelOrderView) + + buttonsLayout = QHBoxLayout() + cancelButton = widgets.cancelPushButton('Cancel') + okButton = widgets.okPushButton('Ok') + buttonsLayout.addStretch(1) + buttonsLayout.addWidget(cancelButton) + buttonsLayout.addSpacing(20) + buttonsLayout.addWidget(okButton) + layout.addLayout(buttonsLayout) + + cancelButton.clicked.connect(dialog.reject) + okButton.clicked.connect(dialog.accept) + + if dialog.exec_() != QDialog.Accepted: + return None + + return modelOrderView.items() + + def ok_cb(self, event=None): self.clickedButton = self.sender() - self.cancel = False - item = self.listBox.currentItem() + + if self.allowMultiSelection: + if not self.selectionSequence: + return + + selected_models = list(self.selectionSequence) + if len(selected_models) > 1: + ordered_models = self.askSelectedModelsOrder(selected_models) + if ordered_models is None: + return + selected_models = ordered_models + + self.selectedModel = selected_models + self.cancel = False + self.close() + return + + selected_items = self.listBox.selectedItems() + if not selected_items: + return + + selected_models = [item.text() for item in selected_items] + if len(selected_models) > 1: + ordered_models = self.askSelectedModelsOrder(selected_models) + if ordered_models is None: + return + self.selectedModel = ordered_models + self.cancel = False + self.close() + return + + item = selected_items[0] model = item.text() if model == 'Add custom model...': - modelFilePath = addCustomModelMessages(self) - if modelFilePath is None: + modelName = self._runAddCustomModelWorkflow() + if modelName is None: return - myutils.store_custom_model_path(modelFilePath) - modelName = os.path.basename(os.path.dirname(modelFilePath)) item = QListWidgetItem(modelName) self.listBox.addItem(item) self.listBox.setCurrentItem(item) elif model == 'Automatic thresholding': - self.selectedModel = 'thresholding' + self.selectedModel = model + self.cancel = False self.close() else: self.selectedModel = model + self.cancel = False self.close() - def cancel_cb(self, event): + def cancel_cb(self, event=None): self.cancel = True self.selectedModel = None self.close() @@ -5725,7 +5859,7 @@ def __init__(self, text, parent=None): mainLayout.addLayout(buttonsLayout) self.setLayout(mainLayout) - self.setFont(font) + self.setFont(fonts.font) class startStopFramesDialog(QBaseDialog): def __init__( @@ -5760,7 +5894,7 @@ def __init__( okButton.clicked.connect(self.ok_cb) cancelButton.clicked.connect(self.close) - self.setFont(font) + self.setFont(fonts.font) def ok_cb(self): if self.selectFramesGroupbox.warningLabel.text(): @@ -6186,7 +6320,8 @@ def __init__( self.addAdditionalValues(additionalValues) self.setLayout(mainLayout) - self.setFont(font) + if font is not None: + self.setFont(font) # self.setModal(True) def showWhySizeTisGrayed(self): @@ -6723,9 +6858,7 @@ def __init__(self, mainWindow): seeHereLabel.setTextFormat(Qt.RichText) seeHereLabel.setTextInteractionFlags(Qt.TextBrowserInteraction) seeHereLabel.setOpenExternalLinks(True) - font = QFont() - font.setPixelSize(12) - seeHereLabel.setFont(font) + seeHereLabel.setFont(fonts.font) seeHereLabel.setStyleSheet("padding:12px 0px 0px 0px;") paramsLayout.addWidget(seeHereLabel, row, 0, 1, 2) @@ -7112,7 +7245,7 @@ def __init__( layout.addLayout(buttonsLayout, 2, 1) self.setLayout(layout) - self.setFont(font) + self.setFont(fonts.font) def copyErrorMessage(self): cb = QApplication.clipboard() @@ -7444,7 +7577,7 @@ def setPosData(self): # self.img.setCurrentPosIndex(self.pos_i) # self.img.minMaxValuesMapper = self.mainWin.img1.minMaxValuesMapper self.origLab = self.posData.lab.copy() - self.origRp = skimage.measure.regionprops(self.origLab) + self.origRp = skimage.measure.regionprops(self.origLab) # why seperate rp here? self.origObjs = {obj.label:obj for obj in self.origRp} def valueChanged(self, value): @@ -7459,7 +7592,6 @@ def apply(self, origLab=None): if ccaAnnotRemoved: self.mainWin.updateAllImages() - if origLab is None: origLab = self.origLab.copy() @@ -8030,7 +8162,7 @@ def addAlphaScrollbar(self, channelName, imageItem, alphaScrollBar=None): if alphaScrollBar is None: alphaScrollBar = QScrollBar(Qt.Horizontal) label = QLabel(f'Alpha {channelName}') - label.setFont(font) + label.setFont(fonts.font) label.hide() alphaScrollBar.imageItem = imageItem alphaScrollBar.label = label @@ -8585,7 +8717,7 @@ def __init__(self, expPaths: dict, infoPaths: dict=None, parent=None): QAbstractItemView.SelectionMode.ExtendedSelection ) self.treeWidget.setHeaderHidden(True) - self.treeWidget.setFont(font) + self.treeWidget.setFont(fonts.font) for exp_path, positions in expPaths.items(): pathLevels = exp_path.split(os.sep) posFoldersInfo = None @@ -9522,7 +9654,7 @@ def __init__( entryWidget.setText(defaultTxt) if not self.allowText: entryWidget.textChanged[str].connect(self.onTextChanged) - entryWidget.setFont(font) + entryWidget.setFont(fonts.font) entryWidget.setAlignment(Qt.AlignCenter) self.entryWidget = entryWidget @@ -9530,7 +9662,7 @@ def __init__( if allowedValues is not None: notValidLabel = QLabel() notValidLabel.setStyleSheet('color: red') - notValidLabel.setFont(font) + notValidLabel.setFont(fonts.font) notValidLabel.setAlignment(Qt.AlignCenter) self.notValidLabel = notValidLabel @@ -10056,7 +10188,7 @@ def __init__( listBox.addItems(items) listBox.setSelectionMode(QAbstractItemView.SelectionMode.ExtendedSelection) listBox.setCurrentRow(0) - listBox.setFont(font) + listBox.setFont(fonts.font) topLayout.addWidget(listBox) listBox.hide() self.ListBox = listBox @@ -10078,7 +10210,7 @@ def __init__( if showInFileManagerPath is not None: showInFileManagerButton.clicked.connect(self.showInFileManager) - self.setFont(font) + self.setFont(fonts.font) def setSelectedItems(self, selectedItemsText): if self.multiPosButton.isChecked(): @@ -10976,9 +11108,7 @@ def __init__(self, filename, SizeZ, filenamesWithInfo, parent=None): self.setLayout(mainLayout) - font = QFont() - font.setPixelSize(12) - self.setFont(font) + self.setFont(fonts.font) # self.setModal(True) @@ -11839,7 +11969,7 @@ def __init__( printl(traceback.format_exc()) self.setLayout(mainLayout) - self.setFont(font) + self.setFont(fonts.font) # self.setModal(True) def warningNoSegmRecipes(self): @@ -13046,7 +13176,7 @@ def __init__( metricsTreeWidget = QTreeWidget() metricsTreeWidget.setHeaderHidden(True) - metricsTreeWidget.setFont(font) + metricsTreeWidget.setFont(fonts.font) self.metricsTreeWidget = metricsTreeWidget for chName in allChNames: @@ -13237,7 +13367,7 @@ def __init__( testButton.clicked.connect(self.test_cb) self.setLayout(mainLayout) - self.setFont(font) + self.setFont(fonts.font) self.setStyleSheet(TREEWIDGET_STYLESHEET) @@ -13498,7 +13628,7 @@ def __init__(self, posDatas, parent=None): _spinBox = QSpinBox() _spinBox.setMaximum(214748364) _spinBox.setAlignment(Qt.AlignCenter) - _spinBox.setFont(font) + _spinBox.setFont(fonts.font) if posData.acdc_df is not None: _val = posData.acdc_df.index.get_level_values(0).max()+1 else: @@ -13617,7 +13747,7 @@ def __init__(self, acdcDfs, allChNames, parent=None, debug=False): for i, (acdc_df_endname, acdc_df) in enumerate(acdcDfs.items()): metricsTreeWidget = QTreeWidget() metricsTreeWidget.setHeaderHidden(True) - metricsTreeWidget.setFont(font) + metricsTreeWidget.setFont(fonts.font) classified_metrics = measurements.classify_acdc_df_colnames( acdc_df, allChNames @@ -13805,7 +13935,7 @@ def __init__(self, acdcDfs, allChNames, parent=None, debug=False): # self.newColNameLineEdit.editingFinished.connect(self.equationChanged) self.setLayout(mainLayout) - self.setFont(font) + self.setFont(fonts.font) self.setStyleSheet(TREEWIDGET_STYLESHEET) @@ -13993,7 +14123,7 @@ def __init__( row += 1 self.equationsList = widgets.TreeWidget() - self.equationsList.setFont(font) + self.equationsList.setFont(fonts.font) self.equationsList.setHeaderLabels(['Metric', 'Expression']) self.equationsList.setSelectionMode( QAbstractItemView.SelectionMode.ExtendedSelection) @@ -14302,7 +14432,7 @@ def __init__( mainLayout.addSpacing(20) mainLayout.addLayout(buttonsLayout) - self.setFont(font) + self.setFont(fonts.font) self.setLayout(mainLayout) def checkDuplicateShortcuts(self, text): @@ -14429,7 +14559,7 @@ def __init__(self, posData, parent=None): self.setLayout(mainLayout) - self.setFont(font) + self.setFont(fonts.font) def ok_cb(self): self.cancel = False @@ -14645,7 +14775,7 @@ def __init__( self.setLayout(self._layout) - # self.setFont(font) + # self.setFont(fonts.font) self.addButton.clicked.connect(self.addFeatureField) @@ -14928,7 +15058,7 @@ def __init__( mainLayout.addStretch() self.setLayout(mainLayout) - self.setFont(font) + self.setFont(fonts.font) self.unitCombobox.currentTextChanged.connect(self.updateLengthUnit) self.colorButton.clicked.disconnect() @@ -15047,7 +15177,7 @@ def __init__( self.setLayout(mainLayout) - self.setFont(font) + self.setFont(fonts.font) def _warnNonUniqueCategories(self, category_1, category_2): txt = html_utils.paragraph(f""" @@ -15108,7 +15238,7 @@ def __init__( metricsTreeWidget = QTreeWidget() metricsTreeWidget.setHeaderHidden(True) - metricsTreeWidget.setFont(font) + metricsTreeWidget.setFont(fonts.font) self.metricsTreeWidget = metricsTreeWidget for groupName, features in features_groups.items(): @@ -15153,7 +15283,7 @@ def __init__( metricsTreeWidget.itemDoubleClicked.connect(self.addFeatureName) self.setLayout(mainLayout) - self.setFont(font) + self.setFont(fonts.font) self.setStyleSheet(TREEWIDGET_STYLESHEET) @@ -15439,7 +15569,7 @@ def __init__(self, parent=None, title='Input'): self.buttonsLayout = buttonsLayout - self.setFont(font) + self.setFont(fonts.font) self.setLayout(self.mainLayout) def askText(self, prompt, infoText='', allowEmpty=False): @@ -15909,7 +16039,7 @@ def __init__(self, parent=None, **properties): mainLayout.addStretch() self.setLayout(mainLayout) - self.setFont(font) + self.setFont(fonts.font) self.colorButton.clicked.disconnect() self.colorButton.clicked.connect(self.selectColor) @@ -19313,7 +19443,7 @@ def __init__( self.setAcceptDrops(True) - self.setFont(font) + self.setFont(fonts.font) def dragEnterEvent(self, event): event.acceptProposedAction() @@ -19353,12 +19483,59 @@ def expFolderToPosFoldernamesMapper(self): return expPathsPosFoldernamesMapper def ok_cb(self): - self.cancel = False + #verify all selected folders have Images folder: + faultyFolders = [] + for path, selected_pos in self.expFolderToPosFoldernamesMapper().items(): + if selected_pos == ['']: + images_path = myutils.get_images_folderpath(path) + if images_path is None or not os.path.exists(images_path): + faultyFolders.append(path) + + else: + for pos in selected_pos: + pos_path = os.path.join(path, pos) + images_path = myutils.get_images_folderpath(pos_path) + if images_path is None or not os.path.exists(images_path): + faultyFolders.append(pos_path) + + if faultyFolders: + self.warnNoAllValid(faultyFolders) + return + self.paths = self.pathsList() self.selectedExpFolderToPosFoldernamesMapper = ( self.expFolderToPosFoldernamesMapper() ) + if not self.selectedExpFolderToPosFoldernamesMapper: + self.warnEmptySelection() + return + self.cancel = False + self.close() + + def warnNoAllValid(self, faultyFolders=None): + msg = widgets.myMessageBox(wrapText=False) + txt = html_utils.paragraph(f""" + Some of the selected folders (see below) do not contain an Images folder.

+ Please, make sure to select Position folders, the Images folder inside Position folders, or any folder containing Position folders as sub-directories.

+ Thank you for your patience!

+ Selected folders: + + """) + msg.warning( + self, 'Some folders are not valid', txt + ) + + def warnEmptySelection(self): + msg = widgets.myMessageBox(wrapText=False) + txt = html_utils.paragraph(""" + No folder was selected.

+ """) + msg.warning( + self, 'No folder selected', txt + ) def warnNoValidPathsFound(self, selected_path): msg = widgets.myMessageBox(wrapText=False) @@ -19427,11 +19604,11 @@ def addFolderPath(self, selected_path): myutils.addToRecentPaths(selected_path) folder_type = myutils.determine_folder_type(selected_path) - is_pos_folder, is_images_folder, folder_path = folder_type + is_pos_folder, is_images_folder, folder_path = folder_type if is_pos_folder: paths = [selected_path] elif is_images_folder: - paths = [os.path.dirname(selected_path)] + paths = [os.path.dirname(selected_path) if selected_path.endswith('Images') else selected_path] elif self.scanTree: print(f'Scanning selected folder "{selected_path}"...') exp_paths = path.get_posfolderpaths_walk(selected_path) diff --git a/cellacdc/core.py b/cellacdc/core.py index 80ae1df0e..d34733dc2 100755 --- a/cellacdc/core.py +++ b/cellacdc/core.py @@ -50,6 +50,7 @@ from . import measurements from . import favourite_func_metrics_csv_path from . import default_index_cols +from . import regionprops from ._types import ( ChannelsDict @@ -756,6 +757,19 @@ def get_obj_contours( only_longest_contour=True, local=False, ): + if obj is not None and obj_image is None and not local: + if all_external and not all: + obj_image = np.ascontiguousarray(obj.image, dtype=np.uint8) + contours, _ = cv2.findContours( + obj_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE + ) + min_y, min_x, _, _ = obj.bbox + return [np.squeeze(cont, axis=1)+[min_x, min_y] for cont in contours] + if all and hasattr(obj, 'contour_all'): + return obj.contour_all + if only_longest_contour and hasattr(obj, 'contour'): + return obj.contour + if all: retrieveMode = cv2.RETR_CCOMP else: @@ -1273,8 +1287,12 @@ def cca_df_to_acdc_df(cca_df, rp, acdc_df=None): IDs.append(obj.label) is_cell_dead_li.append(0) is_cell_excluded_li.append(0) - xx_centroid.append(int(obj.centroid[1])) - yy_centroid.append(int(obj.centroid[0])) + if isinstance(rp, regionprops.acdcRegionprops): + centroid = rp.get_centroid(obj.label, exact=True) + else: + centroid = obj.centroid + xx_centroid.append(int(centroid[1])) + yy_centroid.append(int(centroid[0])) acdc_df = pd.DataFrame({ 'Cell_ID': IDs, 'is_cell_dead': is_cell_dead_li, @@ -3036,16 +3054,25 @@ def split_connected_components(lab, rp=None, max_ID=None): return split_occured def split_along_convexity_defects( - ID, lab, max_ID, max_i=1, eps_percent=0.01 + ID, lab, max_ID, max_i=1, eps_percent=0.01, rp=None ): - lab_ID_bool = lab == ID + if rp is not None: + obj = rp.get_obj_from_ID(ID) + lab_ID_bool = np.zeros_like(lab[obj.slice], dtype=bool) + lab_ID_bool[obj.image] = True + else: + lab_ID_bool = lab == ID # First try separating by labelling lab_ID = lab_ID_bool.astype(int) rp_ID = skimage.measure.regionprops(lab_ID) split_occured = split_connected_components(lab_ID, rp=rp_ID, max_ID=max_ID) if split_occured: success = True - lab[lab_ID_bool] = lab_ID[lab_ID_bool] + if rp is not None: + lab[obj.slice][obj.image] = lab_ID[obj.image] + else: + lab[lab_ID_bool] = lab_ID[lab_ID_bool] + rp_ID = skimage.measure.regionprops(lab_ID) separateIDs = [obj.label for obj in rp_ID] return lab, success, separateIDs @@ -3093,7 +3120,10 @@ def split_along_convexity_defects( sep_bud_label = temp_sep_bud_lab sep_bud_label_mask = sep_bud_label != 0 # plt.imshow_tk(sep_bud_label, dots_coords=np.asarray(defects_points)) - lab[sep_bud_label_mask] = sep_bud_label[sep_bud_label_mask] + if rp is not None: + lab[obj.slice][sep_bud_label_mask] = sep_bud_label[sep_bud_label_mask] + else: + lab[sep_bud_label_mask] = sep_bud_label[sep_bud_label_mask] max_i += 1 success = True return lab, success, splittedIDs @@ -3194,53 +3224,119 @@ def insert_missing_objects( return segm_dst -def process_lab(task): - i, lab = task - # Assuming this function processes each lab independently - data_dict = {} - rp = skimage.measure.regionprops(lab) - IDs = [obj.label for obj in rp] - data_dict['IDs'] = IDs - data_dict['regionprops'] = rp - data_dict['IDs_idxs'] = {ID: idx for idx, ID in enumerate(IDs)} +### out of date +# def process_lab(task): +# i, lab = task +# # Assuming this function processes each lab independently +# data_dict = {} +# rp = skimage.measure.regionprops(lab) +# IDs = [obj.label for obj in rp] +# data_dict['IDs'] = IDs +# data_dict['regionprops'] = rp +# data_dict['IDs_idxs'] = {ID: idx for idx, ID in enumerate(IDs)} + +# return i, data_dict, IDs # Return index, data_dict, and IDs + +# def parallel_count_objects(posData, logger_func): +# benchmark = True +# #futile attempt to use multiprocessing to speed things up +# logger_func('Counting total number of segmented objects...') + +# allIDs = set() +# seg_data = posData.segm_data + +# # Initialize empty data dictionary to avoid recalculating each time +# tasks = [(i, lab) for i, lab in enumerate(seg_data)] + +# if benchmark: +# t0 = time.perf_counter() +# # Process in batches to optimize memory usage and control parallelism +# with ThreadPoolExecutor() as executor: +# futures = [executor.submit(process_lab, task) for task in tasks] + +# # Process results as they are completed +# for future in tqdm(as_completed(futures), total=len(futures), ncols=100): +# i, data_dict, IDs = future.result() +# posData.allData_li[i] = myutils.get_empty_stored_data_dict() # or directly assign if it's mutable +# posData.allData_li[i]['IDs'] = data_dict['IDs'] +# posData.allData_li[i]['regionprops'] = data_dict['regionprops'] +# posData.allData_li[i]['IDs_idxs'] = data_dict['IDs_idxs'] +# allIDs.update(IDs) - return i, data_dict, IDs # Return index, data_dict, and IDs +# if benchmark: +# t1 = time.perf_counter() +# logger_func(f'Counting objects took {(t1 - t0)*1000:.2f} ms') -def parallel_count_objects(posData, logger_func): - benchmark = True - #futile attempt to use multiprocessing to speed things up - logger_func('Counting total number of segmented objects...') +# return allIDs, posData + +def check_file_time_proximity(file1, file2, max_seconds=300, logger_func=print): + if not os.path.isfile(file1): + return False - allIDs = set() - seg_data = posData.segm_data + if not os.path.isfile(file2): + return False + + mtime1 = os.path.getmtime(file1) + mtime2 = os.path.getmtime(file2) - # Initialize empty data dictionary to avoid recalculating each time - tasks = [(i, lab) for i, lab in enumerate(seg_data)] + if abs(mtime1 - mtime2) <= max_seconds: + return True + else: + logger_func(f'Warning: The files "{file1}" and "{file2}" were not saved within {max_seconds} seconds of each other.') + return False + +def verify_acdc_df_segm(posData: 'load.loadData', logger_func=print): + if posData.segmMetadata is None: + return None + segm_info = posData.segmMetadata[os.path.basename(posData.segm_npz_path)] + imgs_folder = posData.images_path + csv_name = segm_info['acdc_df_segm'] if 'acdc_df_segm' in segm_info else None + if csv_name is None: + return None + csv_filepath = os.path.join(imgs_folder, csv_name) + + # verify that that both files exist and are within the allowed time proximity + success = check_file_time_proximity( + posData.segm_npz_path, csv_filepath, max_seconds=120, logger_func=logger_func + ) + if not success: + return None + + return csv_filepath - if benchmark: - t0 = time.perf_counter() - # Process in batches to optimize memory usage and control parallelism - with ThreadPoolExecutor() as executor: - futures = [executor.submit(process_lab, task) for task in tasks] +def verify_add_data_segm_proximity(posData: 'load.loadData', logger_func=print): + segm_path = posData.segm_npz_path + segm_filename = os.path.basename(segm_path).replace('.npz', '') + add_data_folder = os.path.join(posData.images_path, segm_filename) + + centroids_path = os.path.join(add_data_folder, 'centroids.pkl') + # IDs_path = os.path.join(add_data_folder, 'IDs.pkl') + centroids_IDs_exact_path = os.path.join(add_data_folder, 'centroids_IDs_exact.pkl') + # ID_to_idx_path = os.path.join(add_data_folder, 'ID_to_idx.pkl') + + ok = [True] * 2 + for idx, file in enumerate([centroids_path, centroids_IDs_exact_path]): + success = check_file_time_proximity( + segm_path, file, max_seconds=120, logger_func=logger_func + ) + if not success: + ok[idx] = False + + return { + 'centroids': centroids_path if ok[0] else None, + # 'IDs': IDs_path if ok[1] else None, + 'centroids_IDs_exact': centroids_IDs_exact_path if ok[1] else None, + # 'ID_to_idx': ID_to_idx_path if ok[3] else None, + } - # Process results as they are completed - for future in tqdm(as_completed(futures), total=len(futures), ncols=100): - i, data_dict, IDs = future.result() - posData.allData_li[i] = myutils.get_empty_stored_data_dict() # or directly assign if it's mutable - posData.allData_li[i]['IDs'] = data_dict['IDs'] - posData.allData_li[i]['regionprops'] = data_dict['regionprops'] - posData.allData_li[i]['IDs_idxs'] = data_dict['IDs_idxs'] - allIDs.update(IDs) - if benchmark: - t1 = time.perf_counter() - logger_func(f'Counting objects took {(t1 - t0)*1000:.2f} ms') - - return allIDs, posData - -def count_objects(posData, logger_func): - benchmark = False - +# WARNING: this function has been attempted to be optimized by +# parallelization, loading data from last session +# The main bottleneck seams to be the rp creation (not even for example getting the IDs or centorids) +# Total time spend optimising here +# >5 hrs +# please update this if you try to optimize again +def count_objects_and_init_rps(posData: 'load.loadData', logger_func=print): allIDs = set() segm_data = posData.segm_data @@ -3250,23 +3346,14 @@ def count_objects(posData, logger_func): logger_func('Counting total number of segmented objects...') pbar = tqdm(total=len(segm_data), ncols=100) - if benchmark: - t0 = time.perf_counter() for i, lab in enumerate(segm_data): posData.allData_li[i] = myutils.get_empty_stored_data_dict() - rp = skimage.measure.regionprops(lab) - IDs = [obj.label for obj in rp] - posData.allData_li[i]['IDs'] = IDs + rp = regionprops.acdcRegionprops(lab) + IDs = rp.IDs_set posData.allData_li[i]['regionprops'] = rp - posData.allData_li[i]['IDs_idxs'] = { # IDs_idxs[obj.label] = idx - ID: idx for idx, ID in enumerate(IDs) - } allIDs.update(IDs) pbar.update() pbar.close() - if benchmark: - t1 = time.perf_counter() - logger_func(f'Counting objects took {(t1 - t0)*1000:.2f} ms') return allIDs, posData def fix_sparse_directML(verbose=True): diff --git a/cellacdc/dataPrep.py b/cellacdc/dataPrep.py index 10b558445..9856e7f4b 100755 --- a/cellacdc/dataPrep.py +++ b/cellacdc/dataPrep.py @@ -44,6 +44,7 @@ from . import urls from . import io from .help import about +from . import fonts if os.name == 'nt': try: @@ -373,7 +374,7 @@ def gui_createToolBars(self): navigateToolbar.addAction(self.interpAction) self.ROIshapeComboBox = QComboBox() - self.ROIshapeComboBox.setFont(apps.font) + self.ROIshapeComboBox.setFont(fonts.font) self.ROIshapeComboBox.SizeAdjustPolicy(QComboBox.SizeAdjustPolicy.AdjustToContents) self.ROIshapeComboBox.addItems([' 256x256 ']) ROIshapeLabel = QLabel(html_utils.paragraph( diff --git a/cellacdc/debugutils.py b/cellacdc/debugutils.py index b55d0eef1..40b94f717 100644 --- a/cellacdc/debugutils.py +++ b/cellacdc/debugutils.py @@ -1,9 +1,101 @@ import inspect, os, datetime, sys, traceback +import atexit +import linecache +from collections import defaultdict from . import cellacdc_path, myutils import gc import psutil +import time +import functools + +_LINE_BENCHMARK_TRACE_LIMIT = 10000 + +_LINE_BENCHMARK_STATS = defaultdict( + lambda: { + 'count': 0, + 'traced_count': 0, + 'untracked_count': 0, + 'total_time': 0.0, + 'min_time': float('inf'), + 'max_time': 0.0, + 'filename': None, + 'line_stats': defaultdict( + lambda: { + 'count': 0, + 'total_time': 0.0, + 'min_time': float('inf'), + 'max_time': 0.0, + } + ), + } +) + +def _get_benchmark_line_snippet(filename, lineno, max_chars=30): + if lineno == 'return': + return '' + if not filename: + return '' + + line = linecache.getline(filename, lineno).strip() + if not line: + return '' + + if len(line) <= max_chars: + # fill up to max_chars for better alignment + line = line.ljust(max_chars) + return line + return f'{line[:max_chars-3]}...' + +def _print_line_benchmark_session_stats(): + if not _LINE_BENCHMARK_STATS: + return + + print('\nLine benchmark session summary:') + for func_name, stats in sorted(_LINE_BENCHMARK_STATS.items()): + total_count = stats['count'] + traced_count = stats['traced_count'] + untracked_count = stats['untracked_count'] + if total_count == 0: + continue + + if traced_count: + mean_time = stats['total_time'] / traced_count + print( + f'{func_name}: n={total_count} | ' + f'traced={traced_count} | ' + f'untracked={untracked_count} | ' + f'mean={mean_time*1000:.3f} ms | ' + f'min={stats["min_time"]*1000:.3f} ms | ' + f'max={stats["max_time"]*1000:.3f} ms | ' + f'total={stats["total_time"]*1000:.3f} ms' + ) + else: + print( + f'{func_name}: n={total_count} | ' + f'traced=0 | ' + f'untracked={untracked_count}' + ) + + line_stats = stats['line_stats'] + top_lines = sorted( + line_stats.items(), + key=lambda item: item[1]['total_time'], + reverse=True + )[:10] + filename = stats['filename'] + for (start_line, end_line), line_stat in top_lines: + line_mean = line_stat['total_time'] / line_stat['count'] + line_snippet = _get_benchmark_line_snippet(filename, start_line) + print( + f' {line_snippet:<30} {start_line} -> {end_line}: ' + f'n={line_stat["count"]} | ' + f'mean={line_mean*1000:.3f} ms | ' + f'total={line_stat["total_time"]*1000:.3f} ms' + ) + +atexit.register(_print_line_benchmark_session_stats) def showRefGraph(object_str:str, debug:bool=True): """Save a reference graph of the given object type. @@ -206,3 +298,140 @@ def print_largest_classes(package_prefix="cellacdc", top_n=10, max_instances=100 # Example usage: # print_largest_classes("cellacdc", top_n=10) + +# Return a benchmark checkpoint with caller line information. +def return_timer_and_line(benchmarking=True): + if not benchmarking: + return None + timestamp = time.perf_counter() + line = inspect.currentframe().f_back.f_lineno # is super fast! + return (timestamp, line) + +def print_benchmarks(timers, benchmarking=True): + if not benchmarking: + return + checkpoints = [timer for timer in timers if timer is not None] + if len(checkpoints) < 2: + return + + print("Benchmarks:") + for (start_time, start_line), (end_time, end_line) in zip( + checkpoints, checkpoints[1:] + ): + duration = end_time - start_time + print( + f"Line {start_line} -> {end_line}: " + f"{duration:.6f} seconds" + ) + + total_duration = checkpoints[-1][0] - checkpoints[0][0] + print(f"Total: {total_duration:.6f} seconds") + +def line_benchmark(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + stats_key = f'{func.__module__}.{func.__qualname__}' + stats = _LINE_BENCHMARK_STATS[stats_key] + stats['count'] += 1 + + if stats['traced_count'] >= _LINE_BENCHMARK_TRACE_LIMIT: + stats['untracked_count'] += 1 + return func(*args, **kwargs) + + target_code = func.__code__ + filename = target_code.co_filename + checkpoints = [] + last_time = None + last_line = None + + def tracer(frame, event, arg): + nonlocal last_time, last_line + + if frame.f_code is not target_code: + return tracer + + now = time.perf_counter() + + if event == "call": + last_time = now + last_line = frame.f_lineno + return tracer + + if event == "line": + if last_time is not None and last_line is not None: + checkpoints.append((last_line, frame.f_lineno, now - last_time)) + last_time = now + last_line = frame.f_lineno + return tracer + + if event == "return": + if last_time is not None and last_line is not None: + checkpoints.append((last_line, "return", now - last_time)) + return tracer + + return tracer + + old_trace = sys.gettrace() + sys.settrace(tracer) + try: + result = func(*args, **kwargs) + finally: + sys.settrace(old_trace) + + total = sum(dt for _, _, dt in checkpoints) + stats['traced_count'] += 1 + stats['total_time'] += total + stats['min_time'] = min(stats['min_time'], total) + stats['max_time'] = max(stats['max_time'], total) + stats['filename'] = filename + + for start_line, end_line, dt in checkpoints: + line_stat = stats['line_stats'][(start_line, end_line)] + line_stat['count'] += 1 + line_stat['total_time'] += dt + line_stat['min_time'] = min(line_stat['min_time'], dt) + line_stat['max_time'] = max(line_stat['max_time'], dt) + + return result + + return wrapper + +def check_unused_methods(node_name): + import ast + from pathlib import Path + from collections import Counter + + file_path = Path("cellacdc/gui.py") + src = file_path.read_text(encoding="utf-8") + tree = ast.parse(src) + + gui_cls = None + for node in tree.body: + if isinstance(node, ast.ClassDef) and node.name == node_name: + gui_cls = node + break + + if gui_cls is None: + raise RuntimeError(f"{node_name} class not found") + + methods = [] + refs = Counter() + + for node in gui_cls.body: + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + methods.append(node.name) + + for n in ast.walk(node): + # Count any self.method reference (calls, signal connections, passing callbacks, etc.) + if isinstance(n, ast.Attribute) and isinstance(n.value, ast.Name) and n.value.id == "self": + refs[n.attr] += 1 + + # Also count guiWin.method(...) direct class-qualified calls inside class + if isinstance(n, ast.Attribute) and isinstance(n.value, ast.Name) and n.value.id == node_name: + refs[n.attr] += 1 + + unused_inside_guiwin = [m for m in methods if refs[m] == 0] + print(f"Total methods: {len(methods)}") + print(f"Potentially unused inside {node_name}: {len(unused_inside_guiwin)}") + for m in sorted(unused_inside_guiwin): + print(m) \ No newline at end of file diff --git a/cellacdc/docs/source/installation.rst b/cellacdc/docs/source/installation.rst index 901bde231..d5fc9fa2d 100644 --- a/cellacdc/docs/source/installation.rst +++ b/cellacdc/docs/source/installation.rst @@ -485,9 +485,26 @@ If you want to try out experimental features (and, if you have time, maybe repor Updating Cell-ACDC installed from source -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ To update Cell-ACDC installed from source, open a terminal window, navigate to the Cell-ACDC folder with the command ``cd Cell_ACDC`` and run ``git pull``. -Since you installed with the ``-e`` flag, pulling with ``git`` is enough. \ No newline at end of file +Since you installed with the ``-e`` flag, pulling with ``git`` is enough. + + + +Compile Cython extensions +------------------------- + +Some of the functions in Cell-ACDC are implemented in Cython, which allows them +to run much faster. However, Cython functions need to be compiled before they +can be used. We provide pre-compiled versions of the Cython extensions for +Windows, macOS and Linux, but if you want to compile them yourself, +you can do so by activating your ``acdc`` environment, installing +Cython and setuptools with the command ``pip install Cython setuptools``, +and running the following command in the terminal from the Cell_ACDC folder: + +.. code-block:: + + python precompile_functions.py build_ext --inplace --build-temp build/temp \ No newline at end of file diff --git a/cellacdc/docs/source/tooltips.rst b/cellacdc/docs/source/tooltips.rst index 8a0404896..9857908af 100644 --- a/cellacdc/docs/source/tooltips.rst +++ b/cellacdc/docs/source/tooltips.rst @@ -447,8 +447,8 @@ Edit tools: Segmentation and tracking * Note: right-click on a background ROI to remove it. * HELP: Use this function if you need to set the background level specific for each object. Cell-ACDC will save the metrics `amount`, `concentration` and `corrected_mean` where the background correction will be performed by subtracting the mean of the signal in the background ROI (for each object). * **Delete everything outside segmented areas (** |delObjsOutSegmMaskAction| **):** Select a segmentation file and delete everything outside segmented area. -* **Hull contour (** |hullContToolButton| **"K"):** Right-click on a cell to replace it with its hull contour. Use it to fill cracks and holes. -* **Fill holes (** |fillHolesToolButton| **"F"):** Right-click on a cell to fill holes. +* **Hull contour (** |hullContToolButton| **"K"):** Right-click on a cell to replace it with its hull contour. Use it to fill cracks and holes. When working with 3D segmentation masks, the default behaviour is to replace only the viewed z-slice. To replace the entire object in 3D, hold "Shift" while right-clicking. +* **Fill holes (** |fillHolesToolButton| **"F"):** Right-click on a cell to fill holes. When working with 3D segmentation masks, the default behaviour is to fill only the viewed z-slice. To fill the entire object in 3D, hold "Shift" while right-clicking. * **Move object mask (** |moveLabelToolButton| **"P"):** Right-click drag and drop a labels to move it around. * **Expand/Shrink object mask (** |expandLabelToolButton| **"E"):** Leave mouse cursor on the label you want to expand/shrink and press arrow up/down on the keyboard to expand/shrink the mask. * **Edit ID (** |editIDbutton| **"N"):** Manually change ID of a cell by right-clicking on cell. When working with 3D segmentation masks, the default behaviour is to edit the ID in all z-slices. To edit the ID only on the viewed z-slice, hold "Shift" while right-clicking. diff --git a/cellacdc/fonts.py b/cellacdc/fonts.py new file mode 100644 index 000000000..31219da52 --- /dev/null +++ b/cellacdc/fonts.py @@ -0,0 +1,14 @@ +from . import GUI_INSTALLED + +if GUI_INSTALLED: + from qtpy.QtGui import QFont + + font = QFont() + font.setPixelSize(12) + italicFont = QFont() + italicFont.setPixelSize(12) + italicFont.setItalic(True) + +else: + font = None + italicFont = None \ No newline at end of file diff --git a/cellacdc/gui.py b/cellacdc/gui.py index 7247f462e..e09b86bbc 100755 --- a/cellacdc/gui.py +++ b/cellacdc/gui.py @@ -90,12 +90,13 @@ from . import is_mac from .trackers.CellACDC import CellACDC_tracker from .cca_functions import _calc_rot_vol -from .myutils import exec_time, setupLogger, ArgSpec +from .myutils import setupLogger, ArgSpec from .help import welcome, about from .trackers.CellACDC_normal_division.CellACDC_normal_division_tracker import ( normal_division_lineage_tree)#, reorg_sister_cells_for_export) from . import debugutils - +from . import regionprops +from . import exec_time from .plot import imshow from . import gui_utils @@ -114,6 +115,11 @@ GREEN_HEX = _palettes.green() +RP_OPT_NUM_CELLS_MIN = 30 # th for trying to do local updates to regionprops, rp becomes slow for high num of cells +RP_OPT_PERC_CUTOUT_MAX = 0.3 # th for trying to do local updates to regionprops, + # if region which we have to update is too large too + # many cells are probably inside and its not worth + # local updating (since we actually need to call RP twice!) custom_annot_path = os.path.join(settings_folderpath, 'custom_annotations.json') shortcut_filepath = os.path.join(settings_folderpath, 'shortcuts.ini') @@ -537,7 +543,7 @@ def initGlobalAttr(self): ] self.lin_tree_df_colnames = self.lin_tree_df_int_cols + self.lin_tree_df_bool_col + self.lin_tree_col_checks - self.SegForLostIDsSettings = {} + self.SegForLostIDsSettings = {} def setWindowIcon(self, icon=None): if icon is None: @@ -1344,7 +1350,6 @@ def gui_createToolBars(self): self.hullContToolButton.action = editToolBar.addWidget(self.hullContToolButton) self.checkableButtons.append(self.hullContToolButton) self.checkableQButtonsGroup.addButton(self.hullContToolButton) - self.functionsNotTested3D.append(self.hullContToolButton) self.widgetsWithShortcut['Hull contour'] = self.hullContToolButton self.fillHolesToolButton = QToolButton(self) @@ -1356,7 +1361,6 @@ def gui_createToolBars(self): ) self.checkableButtons.append(self.fillHolesToolButton) self.checkableQButtonsGroup.addButton(self.fillHolesToolButton) - self.functionsNotTested3D.append(self.fillHolesToolButton) self.widgetsWithShortcut['Fill holes'] = self.fillHolesToolButton self.moveLabelToolButton = QToolButton(self) @@ -2034,6 +2038,7 @@ def gui_createControlsToolbar(self): brushEraserToolBar.addWidget(QLabel(' ')) self.brushAutoFillCheckbox = QCheckBox('Auto-fill holes') + self.brushAutoFillCheckbox.setTristate(False) self.brushAutoFillAction = brushEraserToolBar.addWidget( self.brushAutoFillCheckbox ) @@ -2708,7 +2713,6 @@ def gui_createActions(self): # Edit actions models = myutils.get_list_of_models() - models = [*models, 'local_seg'] # Add local_seg for SegForLostIDsAction self.segmActions = [] self.modelNames = [] self.acdcSegment_li = [] @@ -2765,7 +2769,7 @@ def gui_createActions(self): 'Track current frame with real-time tracker...', self ) self.repeatTrackingMenuAction.setDisabled(True) - self.repeatTrackingMenuAction.setShortcut('Shift+T') + self.repeatTrackingMenuAction.setShortcut('Ctrl+T') self.repeatTrackingVideoAction = QAction( 'Select a tracker and track multiple frames...', self @@ -3777,8 +3781,7 @@ def gui_createQuickSettingsWidgets(self): def showAllContoursToggled(self): if not self.isDataLoaded: return - - self.computeAllContours() + self.updateAllImages() def gui_createImg1Widgets(self): @@ -4692,7 +4695,8 @@ def _gui_createGraphicsItems(self): posData = self.data[self.pos_i] - allIDs, posData = core.count_objects(posData, self.logger.info) + allIDs, posData = core.count_objects_and_init_rps( + posData, self.logger.info) self.highLowResAction.setChecked(True) numItems = len(allIDs) @@ -4947,6 +4951,23 @@ def gui_initImg1BottomWidgets(self): self.zSliceSpinbox.hide() self.SizeZlabel.hide() + def rpCurr2D(self, frame_i=None, slice_i=None, depth_axis=None): + posData = self.data[self.pos_i] + if frame_i is None: + rp = posData.rp + else: + rp = posData.allData_li[frame_i]['regionprops'] + if not self.isSegm3D: + return rp + + if slice_i is None: + slice_i = self.zSliceScrollBar.sliderPosition() + if depth_axis is None: + depth_axis = self.switchPlaneCombobox.depthAxes() + + rp2D = rp.get_slice_rp(slice_i, slicing=depth_axis) + return rp2D + @exception_handler def gui_mousePressEventImg2(self, event: QGraphicsSceneMouseEvent): modifiers = QGuiApplication.keyboardModifiers() @@ -4986,9 +5007,9 @@ def gui_mousePressEventImg2(self, event: QGraphicsSceneMouseEvent): x, y = event.pos().x(), event.pos().y() xdata, ydata = int(x), int(y) - Y, X = self.get_2Dlab(posData.lab).shape + Y, X = self.get_2Dlab(posData.lab, force_z=False).shape if xdata >= 0 and xdata < X and ydata >= 0 and ydata < Y: - ID = self.get_2Dlab(posData.lab)[ydata, xdata] + ID = self.get_2Dlab(posData.lab, force_z=False)[ydata, xdata] else: return @@ -5135,7 +5156,9 @@ def gui_mousePressEventImg2(self, event: QGraphicsSceneMouseEvent): return else: ID = sepID_prompt.EntryID - y, x = posData.rp[posData.IDs_idxs[ID]].centroid[-2:] + + centroid = posData.rp.get_centroid(ID) + y, x = self.getObjCentroid(centroid) xdata, ydata = int(x), int(y) # Store undo state before modifying stuff @@ -5152,7 +5175,7 @@ def gui_mousePressEventImg2(self, event: QGraphicsSceneMouseEvent): # self.set_2Dlab(lab2D) elif not shift: result = core.split_along_convexity_defects( - ID, self.get_2Dlab(posData.lab), max_ID + ID, self.get_2Dlab(posData.lab), max_ID, rp=posData.rp ) lab2D, success, splittedIDs = result self.set_2Dlab(lab2D) @@ -5190,7 +5213,9 @@ def gui_mousePressEventImg2(self, event: QGraphicsSceneMouseEvent): self.storeManualSeparateDrawMode(manualSep.drawMode) # Update data (rp, etc) - self.update_rp() + bbox = self.update_rp_get_bbox(use_bbox=True, specific_IDs=ID) # use old ID to get bbox + specific_IDs = list(splittedIDs) + [ID] + self.update_rp(specific_IDs=specific_IDs, preloaded_bbox=bbox) # Repeat tracking self.trackSubsetIDs(splittedIDs) @@ -5233,13 +5258,24 @@ def gui_mousePressEventImg2(self, event: QGraphicsSceneMouseEvent): if ID in posData.lab: # Store undo state before modifying stuff self.storeUndoRedoStates(False) - obj_idx = posData.IDs.index(ID) - obj = posData.rp[obj_idx] - objMask = self.getObjImage(obj.image, obj.bbox) - localFill = scipy.ndimage.binary_fill_holes(objMask) - posData.lab[self.getObjSlice(obj.slice)][localFill] = ID + if not shift and self.isSegm3D: + rp2D = self.rpCurr2D() + obj = rp2D.get_obj_from_ID(ID) + else: # shift hold or 2D from the getgo + obj = posData.rp.get_obj_from_ID(ID) + + localFill = scipy.ndimage.binary_fill_holes(obj.image) + + if not shift and self.isSegm3D: + curr_z = self.zSliceScrollBar.sliderPosition() + posData.lab[curr_z][obj.slice][localFill] = ID + else: + posData.lab[obj.slice][localFill] = ID - self.update_rp() + # here it is impossible that hole filling overwrites an ID which + # otuches border + + self.update_rp(use_bbox=True, specific_IDs=ID) self.updateAllImages() if not self.fillHolesToolButton.findChild(QAction).isChecked(): @@ -5272,13 +5308,24 @@ def gui_mousePressEventImg2(self, event: QGraphicsSceneMouseEvent): if ID in posData.lab: # Store undo state before modifying stuff self.storeUndoRedoStates(False) - obj_idx = posData.IDs.index(ID) - obj = posData.rp[obj_idx] - objMask = self.getObjImage(obj.image, obj.bbox) - localHull = skimage.morphology.convex_hull_image(objMask) - posData.lab[self.getObjSlice(obj.slice)][localHull] = ID + if not shift and self.isSegm3D: + rp2D = self.rpCurr2D() + obj = rp2D.get_obj_from_ID(ID) + else: + obj = posData.rp.get_obj_from_ID(ID) + + localHull = skimage.morphology.convex_hull_image(obj.image) + if not shift and self.isSegm3D: + curr_z = self.zSliceScrollBar.sliderPosition() + hull_lab = posData.lab[curr_z][obj.slice] + else: + hull_lab = posData.lab[obj.slice] - self.update_rp() + IDs_overwritten = np.unique(hull_lab[localHull]) + IDs_overwritten = IDs_overwritten[IDs_overwritten != 0] + hull_lab[localHull] = ID + + self.update_rp(use_bbox=True, specific_IDs=IDs_overwritten) self.updateAllImages() if not self.hullContToolButton.findChild(QAction).isChecked(): @@ -5292,30 +5339,6 @@ def gui_mousePressEventImg2(self, event: QGraphicsSceneMouseEvent): x, y = event.pos().x(), event.pos().y() self.startMovingLabel(x, y) - # Fill holes - elif right_click and self.fillHolesToolButton.isChecked(): - x, y = event.pos().x(), event.pos().y() - xdata, ydata = int(x), int(y) - ID = self.get_2Dlab(posData.lab)[ydata, xdata] - if ID == 0: - nearest_ID = core.nearest_nonzero_2D( - self.get_2Dlab(posData.lab), y, x - ) - clickedBkgrID = apps.QLineEditDialog( - title='Clicked on background', - msg='You clicked on the background.\n' - 'Enter here the ID that you want to ' - 'fill the holes of', - parent=self, allowedValues=posData.IDs, - defaultTxt=str(nearest_ID), - isInteger=True - ) - clickedBkgrID.exec_() - if clickedBkgrID.cancel: - return - else: - ID = clickedBkgrID.EntryID - # Merge IDs elif right_click and self.mergeIDsButton.isChecked(): x, y = event.pos().x(), event.pos().y() @@ -5344,9 +5367,8 @@ def gui_mousePressEventImg2(self, event: QGraphicsSceneMouseEvent): self.storeUndoRedoStates(False) self.firstID = ID - obj_idx = posData.IDs_idxs[ID] - obj = posData.rp[obj_idx] - yc, xc = self.getObjCentroid(obj.centroid) + centroid = posData.rp.get_centroid(ID) # maybe use 2D centroid here? + yc, xc = self.getObjCentroid(centroid) self.clickObjYc, self.clickObjXc = int(yc), int(xc) # Edit ID @@ -5373,8 +5395,8 @@ def gui_mousePressEventImg2(self, event: QGraphicsSceneMouseEvent): else: ID = editID_prompt.EntryID - obj_idx = posData.IDs_idxs[ID] - y, x = posData.rp[obj_idx].centroid[-2:] + centroid = posData.rp.get_centroid(ID, exact=True) + y, x = self.getObjCentroid(centroid) xdata, ydata = int(x), int(y) posData.disableAutoActivateViewerWindow = True @@ -5402,7 +5424,7 @@ def gui_mousePressEventImg2(self, event: QGraphicsSceneMouseEvent): return if editID.assignNewID: - self.assignNewIDfromClickedID(ID, event) + self.assignNewIDfromClickedID(ID, event, shift=shift) return if not self.doNotAskAgainExistingID: @@ -5646,7 +5668,7 @@ def expandLabel(self, dilation=True): ID = self.hoverLabelID - obj = posData.rp[posData.IDs.index(ID)] + obj = posData.rp.get_obj_from_ID(ID) if reinitExpandingLab: # Store undo state before modifying stuff @@ -5679,7 +5701,9 @@ def expandLabel(self, dilation=True): expandedLab[self.currentLab2D>0] = 0 # Get coords of the dilated/eroded object - expandedObj = skimage.measure.regionprops(expandedLab)[0] + expandedObj = regionprops.acdcRegionprops( + expandedLab, precache_centroids=False)[0] + expandedObj_bbox = expandedObj.bbox expandedObjCoords = (expandedObj.coords[:,-2], expandedObj.coords[:,-1]) # Add the dilated/erored object @@ -5689,7 +5713,10 @@ def expandLabel(self, dilation=True): self.set_2Dlab(lab_2D) self.currentLab2D = lab_2D - self.update_rp() + preloaded_bbox = self.update_rp_get_bbox(custom_bbox=expandedObj_bbox) + self.update_rp(preloaded_bbox=preloaded_bbox, specific_IDs=ID) + # we dont draw over other IDs so this is rare case where its fine + # to just have tight bbox and specific_IDs=ID if self.labelsGrad.showLabelsImgAction.isChecked(): self.img2.setImage(img=self.currentLab2D, autoLevels=False) @@ -5712,7 +5739,7 @@ def startMovingLabel(self, xPos, yPos): self.searchedIDitemLeft.setData([], []) self.movingID = ID self.prevMovePos = (xdata, ydata) - movingObj = posData.rp[posData.IDs.index(ID)] + movingObj = posData.rp.get_obj_from_ID(ID) self.movingObjCoords = movingObj.coords.copy() yy, xx = movingObj.coords[:,-2], movingObj.coords[:,-1] self.currentLab2D[yy, xx] = 0 @@ -5788,7 +5815,7 @@ def gui_mouseDragEventImg1(self, event): return posData = self.data[self.pos_i] - Y, X = self.get_2Dlab(posData.lab).shape + Y, X = self.get_2Dlab(posData.lab, force_z=False).shape xdata, ydata = int(x), int(y) if not myutils.is_in_bounds(xdata, ydata, X, Y): return @@ -5798,7 +5825,7 @@ def gui_mouseDragEventImg1(self, event): # Brush dragging mouse --> keep brushing elif self.isMouseDragImg1 and self.brushButton.isChecked(): - lab_2D = self.get_2Dlab(posData.lab) + lab_2D = self.get_2Dlab(posData.lab, force_z=False) # t1 = time.perf_counter() @@ -5835,7 +5862,7 @@ def gui_mouseDragEventImg1(self, event): # t5 = time.perf_counter() - lab2D = self.get_2Dlab(posData.lab) + lab2D = self.get_2Dlab(posData.lab, force_z=False) brushMask = np.logical_and( lab2D[diskSlice] == posData.brushID, diskMask ) @@ -5860,7 +5887,7 @@ def gui_mouseDragEventImg1(self, event): # Eraser dragging mouse --> keep erasing elif self.isMouseDragImg1 and self.eraserButton.isChecked(): posData = self.data[self.pos_i] - lab_2D = self.get_2Dlab(posData.lab) + lab_2D = self.get_2Dlab(posData.lab, force_z=False) rrPoly, ccPoly = self.getPolygonBrush((y, x), Y, X) ymin, xmin, ymax, xmax, diskMask = self.getDiskMask(xdata, ydata) @@ -5948,18 +5975,28 @@ def gui_mouseDragEventImg1(self, event): self.zoomRectItem.setSize((w, h)) # @exec_time - def fillHolesID(self, ID, sender='brush'): + def fillHolesID(self, ID, sender='brush', enabled=None): posData = self.data[self.pos_i] if sender == 'brush': - if not self.brushAutoFillCheckbox.isChecked(): + if enabled is None: + enabled = self.brushAutoFillCheckbox.isChecked() + + if not enabled: + return False + + if not self.brushButton.isChecked(): return False - lab2D = self.get_2Dlab(posData.lab) + lab2D = self.get_2Dlab(posData.lab, force_z=False) mask = lab2D == ID filledMask = scipy.ndimage.binary_fill_holes(mask) - lab2D[filledMask] = ID + newFilledMask = np.logical_and(filledMask, ~mask) + if not np.any(newFilledMask): + return False - self.set_2Dlab(lab2D) + # Apply only newly filled pixels to avoid rewriting the full 3D + # stack when editing in projection mode. + self.applyBrushMask(newFilledMask, ID) return True return False @@ -5984,11 +6021,9 @@ def highlightSearchedIDcheckBoxToggled(self, checked): self.highlightedID = self.getHighlightedID() if self.highlightedID == 0: return - objIdx = posData.IDs_idxs[self.highlightedID] - obj_idx = posData.IDs_idxs.get(self.highlightedID) - if obj_idx is None: + obj = posData.rp.get_obj_from_ID(self.highlightedID) + if obj is None: return - obj = posData.rp[objIdx] self.goToZsliceSearchedID(obj) def setHighlightID(self, doHighlight): @@ -6008,13 +6043,12 @@ def propsWidgetIDvalueChanged(self, ID): return propsQGBox = self.guiTabControl.propsQGBox - obj_idx = posData.IDs_idxs.get(ID) - if obj_idx is None: + obj = posData.rp.get_obj_from_ID(ID) + if obj is None: s = f'Object ID {int(ID):d} does not exist' propsQGBox.notExistingIDLabel.setText(s) return - obj = posData.rp[obj_idx] self.goToZsliceSearchedID(obj) self.updatePropsWidget(int(ID)) @@ -6039,7 +6073,7 @@ def updatePropsWidget(self, ID, fromHover=False): return if posData.rp is None: - self.update_rp() + self.update_rp() # IDK when can this happen? if not posData.IDs: # empty segmentation mask @@ -6051,8 +6085,8 @@ def updatePropsWidget(self, ID, fromHover=False): propsQGBox = self.guiTabControl.propsQGBox - obj_idx = posData.IDs_idxs.get(ID) - if obj_idx is None: + obj = posData.rp.get_obj_from_ID(ID) + if obj is None: s = f'Object ID {int(ID):d} does not exist' propsQGBox.notExistingIDLabel.setText(s) return @@ -6068,8 +6102,6 @@ def updatePropsWidget(self, ID, fromHover=False): if doHighlight: self.highlightSearchedID(ID) - obj = posData.rp[obj_idx] - if self.isSegm3D: if self.zProjComboBox.currentText() == 'single z-slice': local_z = self.z_lab() - obj.bbox[0] @@ -6355,9 +6387,8 @@ def drawTempMothBudLine(self, event, posData): if ID == 0: self.BudMothTempLine.setData([x1, x2], [y1, y2]) else: - obj_idx = posData.IDs_idxs[ID] - obj = posData.rp[obj_idx] - y2, x2 = self.getObjCentroid(obj.centroid) + centroid = posData.rp.get_centroid(ID) + y2, x2 = self.getObjCentroid(centroid) self.BudMothTempLine.setData([x1, x2], [y1, y2]) def drawTempMergeObjsLine(self, event, posData, modifiers): @@ -6370,9 +6401,8 @@ def drawTempMergeObjsLine(self, event, posData, modifiers): y1, x1 = self.clickObjYc, self.clickObjXc ID = self.get_2Dlab(posData.lab)[ydata, xdata] if ID != 0: - obj_idx = posData.IDs_idxs[ID] - obj = posData.rp[obj_idx] - y2, x2 = self.getObjCentroid(obj.centroid) + centroid = posData.rp.get_centroid(ID) + y2, x2 = self.getObjCentroid(centroid) if modifier and ID > 0: self.mergeObjsTempLine.addPoint(x2, y2) @@ -6535,7 +6565,7 @@ def gui_setCursor(self, modifiers, event): def warnAddingPointWithExistingId(self, point_id, table_endname=''): posData = self.data[self.pos_i] - if not point_id in posData.IDs_idxs: + if not point_id in posData.IDs: return True msg = widgets.myMessageBox(wrapText=False) @@ -6695,7 +6725,7 @@ def gui_mouseDragEventImg2(self, event): if mode == 'Viewer': return - Y, X = self.get_2Dlab(posData.lab).shape + Y, X = self.get_2Dlab(posData.lab, force_z=False).shape x, y = event.pos().x(), event.pos().y() xdata, ydata = int(x), int(y) if not myutils.is_in_bounds(xdata, ydata, X, Y): @@ -6704,7 +6734,7 @@ def gui_mouseDragEventImg2(self, event): # Eraser dragging mouse --> keep erasing if self.isMouseDragImg2 and self.eraserButton.isChecked(): posData = self.data[self.pos_i] - lab_2D = self.get_2Dlab(posData.lab) + lab_2D = self.get_2Dlab(posData.lab, force_z=False) Y, X = lab_2D.shape x, y = event.pos().x(), event.pos().y() xdata, ydata = int(x), int(y) @@ -6735,7 +6765,7 @@ def gui_mouseDragEventImg2(self, event): # Brush paint dragging mouse --> keep painting if self.isMouseDragImg2 and self.brushButton.isChecked(): posData = self.data[self.pos_i] - lab_2D = self.get_2Dlab(posData.lab) + lab_2D = self.get_2Dlab(posData.lab, force_z=False) Y, X = lab_2D.shape x, y = event.pos().x(), event.pos().y() xdata, ydata = int(x), int(y) @@ -6775,7 +6805,7 @@ def gui_mouseReleaseEventImg2(self, event): if mode == 'Viewer': return - Y, X = self.get_2Dlab(posData.lab).shape + Y, X = self.get_2Dlab(posData.lab, force_z=False).shape try: x, y = event.pos().x(), event.pos().y() except Exception as e: @@ -6792,7 +6822,7 @@ def gui_mouseReleaseEventImg2(self, event): self.isMovingLabel = False # Update data (rp, etc) - self.update_rp() + self.update_rp() # IDK can I do optimization here? # Repeat tracking self.tracking(enforce=True, assign_unique_new_IDs=False) @@ -6806,7 +6836,7 @@ def gui_mouseReleaseEventImg2(self, event): elif self.mergeIDsButton.isChecked(): x, y = event.pos().x(), event.pos().y() xdata, ydata = int(x), int(y) - lab2D = self.get_2Dlab(posData.lab) + lab2D = self.get_2Dlab(posData.lab, force_z=False) ID = lab2D[ydata, xdata] if ID == 0: nearest_ID = core.nearest_nonzero_2D( @@ -6826,31 +6856,33 @@ def gui_mouseReleaseEventImg2(self, event): return else: ID = mergeID_prompt.EntryID - obj_idx = posData.IDs_idxs[ID] - obj = posData.rp[obj_idx] - y2, x2 = self.getObjCentroid(obj.centroid) - self.mergeObjsTempLine.addPoint(x2, y2) + centroid = posData.rp.get_centroid(ID) + ydata, xdata = self.getObjCentroid(centroid) + ydata, xdata = int(ydata), int(xdata) xx, yy = self.mergeObjsTempLine.getData() IDs_to_merge = lab2D[yy.astype(int), xx.astype(int)] for ID in IDs_to_merge: if ID == 0: continue - posData.lab[posData.lab==ID] = self.firstID + obj = posData.rp.get_obj_from_ID(ID) + + posData.lab[obj.slice][obj.image] = self.firstID self.mergeObjsTempLine.setData([], []) self.clickObjYc, self.clickObjXc = None, None - - # Update data (rp, etc) - self.update_rp() - + + bbox = self.update_rp_get_bbox(specific_IDs=IDs_to_merge,use_bbox=True) # use old IDs to get bbox + specific_IDs = list(IDs_to_merge) + [self.firstID] + self.update_rp(specific_IDs=specific_IDs,preloaded_bbox=bbox) # update with new IDs ask_back_prop = True if posData.frame_i == 0: ask_back_prop = False prev_IDs = [] else: - prev_IDs = posData.allData_li[posData.frame_i-1]['IDs'] + prev_IDs = ( + posData.allData_li[posData.frame_i-1]['regionprops'].IDs) if all(ID not in prev_IDs for ID in IDs_to_merge): ask_back_prop = False @@ -6889,7 +6921,7 @@ def gui_mouseReleaseEventImg1(self, event): if mode == 'Viewer': return - Y, X = self.get_2Dlab(posData.lab).shape + Y, X = self.get_2Dlab(posData.lab, force_z=False).shape x, y = event.pos().x(), event.pos().y() xdata, ydata = int(x), int(y) if not myutils.is_in_bounds(xdata, ydata, X, Y): @@ -6919,8 +6951,9 @@ def gui_mouseReleaseEventImg1(self, event): if self.isRightClickDragImg1 and self.curvToolButton.isChecked(): self.isRightClickDragImg1 = False try: - self.curvToolSplineToObj(isRightClick=True) - self.update_rp() + mask, returnID = self.curvToolSplineToObj(isRightClick=True) + if mask is not None: + self.update_rp() # how can I optimize this? I think not possible tbh if self.autoIDcheckbox.isChecked(): self.trackManuallyAddedObject(posData.brushID, True) if self.isSnapshot: @@ -6940,13 +6973,18 @@ def gui_mouseReleaseEventImg1(self, event): self.isMouseDragImg1 = False self.clearTempBrushImage() + + erasedIDs = [ID for ID in self.erasedIDs if ID != 0] # Update data (rp, etc) - self.update_rp() + self.update_rp( + use_curr_view=True, + specific_IDs=erasedIDs or None, + ) # only visible stuff can be deleted doUpdateImages = self.checkWarnDeletedIDwithEraser() - if doUpdateImages: + if not doUpdateImages: self.updateAllImages() # Brush button mouse release @@ -6967,7 +7005,8 @@ def gui_mouseReleaseEventImg1(self, event): posData.lab[self.flood_mask] = posData.brushID # Update data (rp, etc) - self.update_rp() + # only visible stuff can be added, plus doesnt draw over eixisting + self.update_rp(use_curr_view=True, specific_IDs=posData.brushID) # Repeat tracking self.trackManuallyAddedObject(posData.brushID, self.isNewID) @@ -7048,7 +7087,7 @@ def gui_mouseReleaseEventImg1(self, event): self.isMovingLabel = False # Update data (rp, etc) - self.update_rp() + self.update_rp(use_curr_view=True) # only visible stuff can be moved # Repeat tracking self.tracking(enforce=True, assign_unique_new_IDs=False) @@ -7083,9 +7122,9 @@ def gui_mouseReleaseEventImg1(self, event): return else: ID = mothID_prompt.EntryID - obj_idx = posData.IDs.index(ID) - y, x = posData.rp[obj_idx].centroid - xdata, ydata = int(x), int(y) + centroid = posData.rp.get_centroid(ID) + ydata, xdata = self.getObjCentroid(centroid) + ydata, xdata = int(ydata), int(xdata) if self.isSnapshot: # Store undo state before modifying stuff @@ -7116,11 +7155,9 @@ def gui_mouseReleaseEventImg1(self, event): # on a mother budID = self.get_2Dlab(posData.lab)[self.yClickBud, self.xClickBud] new_mothID = self.get_2Dlab(posData.lab)[ydata, xdata] - bud_obj_idx = posData.IDs.index(budID) - new_moth_obj_idx = posData.IDs.index(new_mothID) - rp_budID = posData.rp[bud_obj_idx] - rp_new_mothID = posData.rp[new_moth_obj_idx] - if rp_budID.area >= rp_new_mothID.area: + bug_obj = posData.rp.get_obj_from_ID(budID) + new_mother_obj = posData.rp.get_obj_from_ID(new_mothID) + if bug_obj.area >= new_mother_obj.area: self.assignBudMothButton.setChecked(False) msg = widgets.myMessageBox() txt = ( @@ -7576,7 +7613,7 @@ def gui_mousePressEventImg1(self, event: QMouseEvent): if left_click and canBrush: x, y = event.pos().x(), event.pos().y() xdata, ydata = int(x), int(y) - lab_2D = self.get_2Dlab(posData.lab) + lab_2D = self.get_2Dlab(posData.lab, force_z=False) Y, X = lab_2D.shape # Store undo state before modifying stuff @@ -7614,7 +7651,7 @@ def gui_mousePressEventImg1(self, event: QMouseEvent): self.setImageImg2(updateLookuptable=False) how = self.drawIDsContComboBox.currentText() - lab2D = self.get_2Dlab(posData.lab) + lab2D = self.get_2Dlab(posData.lab, force_z=False) self.globalBrushMask = np.zeros(lab2D.shape, dtype=bool) brushMask = localLab == posData.brushID brushMask = np.logical_and(brushMask, diskMask) @@ -7627,7 +7664,7 @@ def gui_mousePressEventImg1(self, event: QMouseEvent): elif left_click and canErase: x, y = event.pos().x(), event.pos().y() xdata, ydata = int(x), int(y) - lab_2D = self.get_2Dlab(posData.lab) + lab_2D = self.get_2Dlab(posData.lab, force_z=False) Y, X = lab_2D.shape # Store undo state before modifying stuff @@ -7859,7 +7896,7 @@ def gui_mousePressEventImg1(self, event: QMouseEvent): elif right_click and copyContourON: hoverLostID = self.ax1_lostObjScatterItem.hoverLostID self.copyLostObjectMask(hoverLostID) - self.update_rp() + self.update_rp(use_curr_view=True) # only visible self.updateAllImages() self.store_data() @@ -7909,7 +7946,7 @@ def gui_mousePressEventImg1(self, event: QMouseEvent): if closeSpline: self.splineHoverON = False self.curvToolSplineToObj() - self.update_rp() + self.update_rp() # dont think I can optimize this if self.autoIDcheckbox.isChecked(): self.trackManuallyAddedObject(posData.brushID, True) if self.isSnapshot: @@ -7986,21 +8023,28 @@ def gui_mousePressEventImg1(self, event: QMouseEvent): posData = self.data[self.pos_i] currentIDs = posData.IDs.copy() if manualTrackID in currentIDs: - tempID = max(currentIDs) + 1 - posData.lab[posData.lab == clickedID] = tempID - posData.lab[posData.lab == manualTrackID] = clickedID - posData.lab[posData.lab == tempID] = manualTrackID + clicked_obj = posData.rp.get_obj_from_ID(clickedID) + manual_track_obj = posData.rp.get_obj_from_ID(manualTrackID) + posData.lab[clicked_obj.slice][clicked_obj.image] = manualTrackID + posData.lab[manual_track_obj.slice][manual_track_obj.image] = ( + clickedID) self.manualTrackingToolbar.showWarning( f'The ID {manualTrackID} already exists --> ' f'ID {manualTrackID} has been swapped with {clickedID}' ) + assignments = {clickedID: manualTrackID, + manualTrackID: clickedID} else: - posData.lab[posData.lab == clickedID] = manualTrackID + clicked_obj = posData.rp.get_obj_from_ID(clickedID) + posData.lab[clicked_obj.slice][clicked_obj.image] = manualTrackID self.manualTrackingToolbar.showInfo( f'ID {clickedID} changed to {manualTrackID}.' ) + assignments = {clickedID: manualTrackID} - self.update_rp() + # only ID change, so use assignments + # not 3D ready yet? Otherwise I must set assignments to None + self.update_rp(assignments=assignments) self.updateAllImages() elif right_click and manualBackgroundON: @@ -8064,9 +8108,9 @@ def gui_mousePressEventImg1(self, event: QMouseEvent): return else: ID = divID_prompt.EntryID - obj_idx = posData.IDs.index(ID) - y, x = posData.rp[obj_idx].centroid - xdata, ydata = int(x), int(y) + centroid = posData.rp.get_centroid(ID) + ydata, xdata = self.getObjCentroid(centroid) + ydata, xdata = int(ydata), int(xdata) if not self.isSnapshot: # Store undo state before modifying stuff @@ -8109,8 +8153,9 @@ def gui_mousePressEventImg1(self, event: QMouseEvent): ID = budID_prompt.EntryID obj_idx = posData.IDs.index(ID) - y, x = posData.rp[obj_idx].centroid - xdata, ydata = int(x), int(y) + centroid = posData.rp.get_centroid(ID) + ydata, xdata = self.getObjCentroid(centroid) + ydata, xdata = int(ydata), int(xdata) relationship = posData.cca_df.at[ID, 'relationship'] is_history_known = posData.cca_df.at[ID, 'is_history_known'] @@ -8156,9 +8201,9 @@ def gui_mousePressEventImg1(self, event: QMouseEvent): return else: ID = unknownID_prompt.EntryID - obj_idx = posData.IDs.index(ID) - y, x = posData.rp[obj_idx].centroid - xdata, ydata = int(x), int(y) + centroid = posData.rp.get_centroid(ID) + ydata, xdata = self.getObjCentroid(centroid) + ydata, xdata = int(ydata), int(xdata) self.annotateIsHistoryKnown(ID) if not self.setIsHistoryKnownButton.findChild(QAction).isChecked(): @@ -8185,9 +8230,9 @@ def gui_mousePressEventImg1(self, event: QMouseEvent): return else: ID = clickedBkgrDialog.EntryID - obj_idx = posData.IDs.index(ID) - y, x = posData.rp[obj_idx].centroid - xdata, ydata = int(x), int(y) + centroid = posData.rp.get_centroid(ID) + ydata, xdata = self.getObjCentroid(centroid) + ydata, xdata = int(ydata), int(xdata) button = self.doCustomAnnotation(ID) if button is None: @@ -8406,103 +8451,325 @@ def gui_addCreatedAxesItems(self): self.ax1.exportMaskImageItem = self.exportMaskImageItem def SegForLostIDsSetSettings(self): + posData = self.data[self.pos_i] + displayed_input_label = 'Displayed image' + + recipe_json_path = os.path.join( + settings_folderpath, 'segmentation_for_lostIDs_recipe.json' + ) try: - prev_model = str(self.df_settings.at['SegForLostIDsModel', 'value']) + prev_models = [ + model.strip() for model in str( + self.df_settings.at['SegForLostIDsModel', 'value'] + ).split(',') if model.strip() + ] except KeyError: - prev_model = None - win = apps.QDialogSelectModel(parent=self, customFirst=prev_model) + prev_models = [] + + has_last_recipe = bool(prev_models) and os.path.exists(recipe_json_path) + seg_for_lost_ids_info = ( + 'Segmentation for lost IDs settings

' + 'Use this dialog to define the segmentation workflow used for ' + 'resegmenting local neighborhood lost IDs. Other already segmented cells are filled ' + 'with background, which makes even dimm cells seem bright after ' + 'rescaling before resegmentation. This is especially usefull for ' + 'cells which have varying intensities over time, like FUCCI cells.

' + 'How model selection works
' + '- You can select one model or multiple models.
' + '- In multi-selection mode you can include the same model multiple ' + 'times (for example, model A, then model B, then model A again).
' + '- After confirming, you can reorder the selected models. The order ' + 'is the execution order. ' + '- You then will be asked to set model parameters in the order selected.

' + ' - Pay special attention to the additional "Settings for local ' + 'segmentation" section, here you can for example select any image as input.

' + 'Load last selection...
' + 'Restores only the list of selected model names (the recipe order ' + 'selection), then lets you continue configuring parameters.

' + 'Load last recipe...
' + 'Loads the complete saved recipe from disk, including model order and ' + 'all model-specific settings (when available).

' + 'Add custom model...
' + 'Lets you register an additional local custom model and include it in ' + 'the sequence.

' + 'Tip: if you want to run the same model twice with different ' + 'parameters, add it twice and configure each step independently.' + ) + win = apps.QDialogSelectModel( + parent=self, + allowMultiSelection=True, + lastSelection=prev_models, + addSelectLastSelectionButton=bool(prev_models), + addSelectLastRecipeButton=has_last_recipe, + custom_title='Select model(s) for segmentation of lost IDs', + info_label=seg_for_lost_ids_info, + ) win.exec_() if win.cancel: self.logger.info('Seg for lost IDs cancelled.') return - base_model_name = win.selectedModel - if base_model_name: - self.df_settings.at['SegForLostIDsModel', 'value'] = base_model_name + if getattr(win, 'loadLastRecipe', False): + self.logger.info('Loading last segmentation recipe for lost IDs...') + try: + with open(recipe_json_path, 'r') as f: + recipe_data = json.load(f) + model_settings = [] + for entry in recipe_data['models']: + model_settings.append({ + 'win': None, + 'init_kwargs_new': entry['init_kwargs_new'], + 'args_new': entry['args_new'], + 'base_model_name': entry['base_model_name'], + 'init_kwargs': entry.get('init_kwargs', {}), + 'model_kwargs': entry.get('model_kwargs', {}), + 'preproc_recipe': entry.get('preproc_recipe', None), + 'applyPostProcessing': entry.get('applyPostProcessing', False), + 'standardPostProcessKwargs': entry.get('standardPostProcessKwargs', {}), + 'customPostProcessFeatures': entry.get('customPostProcessFeatures', None), + 'customPostProcessGroupedFeatures': entry.get('customPostProcessGroupedFeatures', None), + }) + self.SegForLostIDsSettings = {'models_settings': model_settings} + # Restore model names in settings + restored_models = [ + 'Automatic thresholding' + if m['base_model_name'] == 'thresholding' + else m['base_model_name'] + for m in model_settings + ] + self.df_settings.at['SegForLostIDsModel', 'value'] = ( + ', '.join(restored_models) + ) + self.df_settings.to_csv(self.settings_csv_path) + self.logger.info('Last segmentation recipe loaded successfully.') + except Exception as e: + self.logger.error(f'Failed to load last recipe: {e}') + return + + selected_models = win.selectedModel + if isinstance(selected_models, str): + selected_models = [selected_models] + + if not selected_models: + self.logger.info('Seg for lost IDs cancelled.') + return + + if selected_models: + self.df_settings.at['SegForLostIDsModel', 'value'] = ( + ', '.join(selected_models) + ) self.df_settings.to_csv(self.settings_csv_path) - model_name = 'local_seg' + all_extra_params = [ + 'image_channel_name', + 'overlap_threshold', + 'padding', + 'size_perc_diff', + 'distance_filler_growth', + 'allow_only_tracked_cells', + ] + extra_types = { + 'overlap_threshold': float, + 'padding': float, + 'size_perc_diff': float, + 'distance_filler_growth': float, + 'allow_only_tracked_cells': bool, + 'image_channel_name': str, + } + extra_defaults = { + 'overlap_threshold': 0.5, + 'padding': 0.8, + 'size_perc_diff': 0.3, + 'distance_filler_growth': 1., + 'allow_only_tracked_cells': False, + 'image_channel_name': displayed_input_label, + } + extra_desc = { + 'overlap_threshold': ( + 'Overlap threshold with other already segemented cells ' + 'over which newly segmented cells are discarded' + ), + 'padding': ( + 'Padding of the box used for new segmentation around the ' + 'segmentation from the previous frame' + ), + 'size_perc_diff': ( + 'Relative size difference acceptable compared to previous ' + 'frames' + ), + 'distance_filler_growth': ( + 'Cells which are already segmented are filled with random ' + 'noise sampled from background to ensure that they do not ' + 'get segmented again. This parameter controls the additional ' + 'padding around the already segmented cells.' + ), + 'allow_only_tracked_cells': ( + 'If no new cell IDs should be permitted ' + '(based on real time tracking)' + ), + 'image_channel_name': ( + 'Image channel used as model input. ' + 'Select "Displayed image" to use exactly what is currently ' + 'shown in the viewer, or select a specific fluorescence ' + 'channel.' + ), + } - idx = self.modelNames.index(model_name) - acdcSegment = self.acdcSegment_li[idx] + model_settings = [] + remembered_extra_args = {} + for model_idx, selected_model_name in enumerate(selected_models): + model_name = selected_model_name + if model_name == 'Automatic thresholding': + model_name = 'thresholding' + try: + if selected_model_name in self.modelNames: + idx = self.modelNames.index(selected_model_name) + acdcSegment = self.acdcSegment_li[idx] + if acdcSegment is None: + self.logger.info(f'Importing {model_name}...') + acdcSegment = myutils.import_segment_module(model_name) + self.acdcSegment_li[idx] = acdcSegment + else: + self.logger.info(f'Importing {model_name}...') + acdcSegment = myutils.import_segment_module(model_name) + except (ImportError, KeyError) as e: + self.logger.error(f'Error importing {model_name}: {e}') + return - try: - if acdcSegment is None or base_model_name != self.local_seg_base_model_name: - self.logger.info(f'Importing {base_model_name}...') - acdcSegment = myutils.import_segment_module(base_model_name) - self.acdcSegment_li[idx] = acdcSegment - self.local_seg_base_model_name = base_model_name - except (IndexError, ImportError, KeyError) as e: - self.logger.error(f'Error importing {base_model_name}: {e}') - return - - extra_params = ['overlap_threshold', - 'padding', - 'size_perc_diff', - 'distance_filler_growth', - 'max_iterations', - 'allow_only_tracked_cells'] - - extra_types = [float, float, float, float, int, bool] - - extra_defaults = [0.5, 0.8, 0.3, 1., 2, False] - - extra_desc = ['Overlap threshold with other already segemented cells over which newly segmented cells are discarded', - 'Padding of the box used for new segmentation around the segmentation from the previous frame', - 'Relative size difference acceptable compared to previous frames', - """Cells which are already segmented are filled with random noise sampled from background - to ensure that they don't get segmented again. - This parameter controls the additional padding around the already segmented cells.""", - """The algorithm will try and segment the maximum amount - of cells in the image by running the model several - times and filling new found cells with background noise. - How many of these iterations should be run?""", - "If no new cell IDs should be permitted (based on real time tracking)"] - - extra_ArgSpec = [] - for i, param in enumerate(extra_params): - param = ArgSpec(name=param, - default=extra_defaults[i], - type=extra_types[i], - desc=extra_desc[i], - docstring='') - - extra_ArgSpec.append(param) + extra_params = all_extra_params - init_params, segment_params = myutils.getModelArgSpec(acdcSegment) - segment_params = [arg for arg in segment_params if arg[0] != 'diameter'] - - extraParamsTitle = 'Settings for local segmentation' - win = self.initSegmModelParams( - base_model_name, acdcSegment, init_params, segment_params, - extraParams=extra_ArgSpec, extraParamsTitle=extraParamsTitle, - initLastParams=True, ini_filename='segmentation_for_lostIDs.ini', - ) + available_fluo_channels = [ + ch for ch in posData.chNames if ch != self.user_ch_name + ] + channel_options = [displayed_input_label, *available_fluo_channels] - if win is None: - self.logger.info('Segmentation for lost IDs cancelled.') - return + class _SegForLostIDsInputChannelType: + values = channel_options - init_kwargs_new = {} - args_new = {} - for key, val in win.init_kwargs.items(): - if key in extra_params: - args_new[key] = val - else: - init_kwargs_new[key] = val + extra_types['image_channel_name'] = _SegForLostIDsInputChannelType + + extra_ArgSpec = [] + for param in extra_params: + param_arg = ArgSpec( + name=param, + default=extra_defaults[param], + type=extra_types[param], + desc=extra_desc[param], + docstring='' + ) + extra_ArgSpec.append(param_arg) + + init_params, segment_params = myutils.getModelArgSpec(acdcSegment) + segment_params = [ + arg for arg in segment_params if arg[0] != 'diameter' + ] + + initLastParams = True + if model_name == 'thresholding': + win_thresh = apps.QDialogAutomaticThresholding( + parent=self, isSegm3D=self.isSegm3D + ) + win_thresh.exec_() + if win_thresh.cancel: + self.logger.info('Segmentation for lost IDs cancelled.') + return + self.model_kwargs = win_thresh.segment_kwargs + thresh_method = self.model_kwargs['threshold_method'] + gauss_sigma = self.model_kwargs['gauss_sigma'] + segment_params = myutils.insertModelArgSpec( + segment_params, 'threshold_method', thresh_method + ) + segment_params = myutils.insertModelArgSpec( + segment_params, 'gauss_sigma', gauss_sigma + ) + initLastParams = False + + extraParamsTitle = ( + f'Settings for local segmentation ' + f'({model_idx + 1}/{len(selected_models)})' + ) + win = self.initSegmModelParams( + model_name, acdcSegment, init_params, segment_params, + extraParams=extra_ArgSpec, + extraParamsTitle=extraParamsTitle, + initLastParams=initLastParams, + ini_filename='segmentation_for_lostIDs.ini', + ) - for key, val in win.extra_kwargs.items(): - if key in extra_params: - args_new[key] = val + if win is None: + self.logger.info('Segmentation for lost IDs cancelled.') + return + + init_kwargs_new = {} + args_new = {} + for key, val in win.init_kwargs.items(): + if key in extra_params: + args_new[key] = val + else: + init_kwargs_new[key] = val + + for key, val in win.extra_kwargs.items(): + if key in extra_params: + if key == 'image_channel_name': + init_kwargs_new[key] = val + else: + args_new[key] = val + + for key, val in remembered_extra_args.items(): + if key == 'image_channel_name': + init_kwargs_new.setdefault(key, val) + continue + args_new.setdefault(key, val) + + if model_idx == 0: + remembered_extra_args = args_new.copy() + remembered_extra_args['image_channel_name'] = ( + init_kwargs_new.get('image_channel_name', displayed_input_label) + ) + + model_settings.append({ + 'win': win, + 'init_kwargs_new': init_kwargs_new, + 'args_new': args_new, + 'base_model_name': model_name, + 'init_kwargs': dict(win.init_kwargs), + 'model_kwargs': dict(win.model_kwargs), + 'preproc_recipe': win.preproc_recipe, + 'applyPostProcessing': win.applyPostProcessing, + 'standardPostProcessKwargs': win.standardPostProcessKwargs, + 'customPostProcessFeatures': win.customPostProcessFeatures, + 'customPostProcessGroupedFeatures': win.customPostProcessGroupedFeatures, + }) self.SegForLostIDsSettings = { - 'win': win, - 'init_kwargs_new': init_kwargs_new, - 'args_new': args_new, - 'base_model_name': base_model_name, + 'models_settings': model_settings, } + # Persist recipe to disk so it survives across sessions + try: + recipe_data = { + 'models': [ + { + 'base_model_name': ms['base_model_name'], + 'init_kwargs_new': ms['init_kwargs_new'], + 'args_new': ms['args_new'], + 'init_kwargs': ms['init_kwargs'], + 'model_kwargs': ms['model_kwargs'], + 'preproc_recipe': ms['preproc_recipe'], + 'applyPostProcessing': ms['applyPostProcessing'], + 'standardPostProcessKwargs': ms['standardPostProcessKwargs'], + 'customPostProcessFeatures': ms['customPostProcessFeatures'], + 'customPostProcessGroupedFeatures': ms['customPostProcessGroupedFeatures'], + } + for ms in model_settings + ] + } + with open(recipe_json_path, 'w') as f: + json.dump(recipe_data, f, indent=2, default=str) + except Exception as e: + self.logger.warning(f'Could not save recipe to disk: {e}') + def segForLostIDsButtonClicked(self): self.setFrameNavigationDisabled(disable=True, why='Segmentation for lost IDs') @@ -8526,9 +8793,9 @@ def onSegForLostInit(self): self.SegForLostIDsSetSettings() self.SegForLostIDsWaitCond.wakeAll() - def SegForLostIDsWorkerAskInstallModel(self, model_name): - myutils.check_install_package(model_name) - self.SegForLostIDsWaitCond.wakeAll() + # def SegForLostIDsWorkerAskInstallModel(self, model_name): + # myutils.check_install_package(model_name) + # self.SegForLostIDsWaitCond.wakeAll() def startSegForLostIDsWorker(self): self.SegForLostIDsMutex = QMutex() @@ -8542,9 +8809,9 @@ def startSegForLostIDsWorker(self): # Connect the worker's signal to the main thread's slot self.SegForLostIDsWorker.sigAskInit.connect(self.onSegForLostInit) - self.SegForLostIDsWorker.sigAskInstallModel.connect( - self.SegForLostIDsWorkerAskInstallModel - ) + # self.SegForLostIDsWorker.sigAskInstallModel.connect( + # self.SegForLostIDsWorkerAskInstallModel + # ) self.SegForLostIDsWorker.sigshowImageDebug.connect( self.showImageDebug ) @@ -8553,8 +8820,13 @@ def startSegForLostIDsWorker(self): self.SegForLostIDsWorkerAskInstallGPU ) - self.SegForLostIDsWorker.sigStoreData.connect(self.onSigStoreDataSegForLostIDsWorker) - self.SegForLostIDsWorker.sigUpdateRP.connect(self.onSigUpdateRPSegForLostIDsWorker) + self.SegForLostIDsWorker.sigStoreData.connect( + self.onSigStoreDataSegForLostIDsWorker) + self.SegForLostIDsWorker.sigUpdateRP.connect( + self.onSigUpdateRPSegForLostIDsWorker) + self.SegForLostIDsWorker.sigGetSegForLostIDsInputImg.connect( + self.onSigGetInputImgSegForLostIDsWorker + ) # self.SegForLostIDsWorker.sigGetData.connect(self.onSigGetDataSegForLostIDsWorker) # self.SegForLostIDsWorker.sigGet2Dlab.connect(self.onSigGet2DlabSegForLostIDsWorker) # self.SegForLostIDsWorker.sigGetTrackedLostIDs.connect(self.onSigGetTrackedSegForLostIDsWorker) @@ -8619,6 +8891,31 @@ def onSigTrackManuallyAddedObjectSegForLostIDsWorker(self, added_IDs, isNewID, w self.trackManuallyAddedObject(added_IDs, isNewID, wl_update=wl_update, wl_track_og_curr=wl_track_og_curr) self.SegForLostIDsWaitCond.wakeAll() + def onSigGetInputImgSegForLostIDsWorker(self, image_channel_name): + displayed_input_label = 'Displayed image' + posData = self.data[self.pos_i] + + if ( + not image_channel_name + or image_channel_name == displayed_input_label + ): + img = self.getDisplayedImg1() + self.SegForLostIDsWorker.inputImgForSegForLostIDs = img + self.SegForLostIDsWaitCond.wakeAll() + return + + self.getChData(requ_ch={image_channel_name}) + + _, filename = self.getPathFromChName(image_channel_name, posData) + fluo_data = posData.fluo_data_dict.get(filename) + if posData.SizeT > 1: + fluo_img_data = fluo_data[posData.frame_i] + else: + fluo_img_data = fluo_data + + self.SegForLostIDsWorker.inputImgForSegForLostIDs = fluo_img_data + self.SegForLostIDsWaitCond.wakeAll() + def onSigStoreData( self, waitcond, pos_i=None, enforce=True, debug=False, @@ -8628,10 +8925,17 @@ def onSigStoreData( autosave=autosave, store_cca_df_copy=store_cca_df_copy) waitcond.wakeAll() - def onSigUpdateRP(self, waitcond, draw=True, debug=False, update_IDs=True, - wl_update=True, wl_track_og_curr=False): - self.update_rp(draw=draw, debug=debug, update_IDs=update_IDs, - wl_update=wl_update, wl_track_og_curr=wl_track_og_curr) + def onSigUpdateRP(self, waitcond, + draw=True, debug=False, # og stuff + assignments=None, deletionIDs=None, # very quick upates, rp labels are changed but rest is same + specific_IDs=None, use_curr_view=False, use_bbox=False, preloaded_bbox=None, # for local updates to PR + wl_update=True, wl_track_og_curr=False,wl_update_lab=False, # wl stuff + ): + self.update_rp(draw=True, debug=False, # og stuff + assignments=None, deletionIDs=None, # very quick upates, rp labels are changed but rest is same + specific_IDs=None, use_curr_view=False, use_bbox=False, preloaded_bbox=None, # for local updates to PR + wl_update=True, wl_track_og_curr=False,wl_update_lab=False, # wl stuff + ) waitcond.wakeAll() def onSigGetData(self, waitcond, debug=False): @@ -8640,7 +8944,7 @@ def onSigGetData(self, waitcond, debug=False): def SegForLostIDsWorkerFinished(self): self.updateAllImages() - self.update_rp() + self.update_rp() # will update when updating segoforlostIDs self.store_data(autosave=True) self.setFrameNavigationDisabled(disable=False, why='Segmentation for lost IDs') @@ -8649,8 +8953,15 @@ def SegForLostIDsWorkerFinished(self): self.progressWin.close() self.progressWin = None - def showImageDebug(self, img): - imshow(img) + def showImageDebug(self, display_info): + title = '' + img_titles = None + if isinstance(display_info, dict): + title = display_info.get('title', '') + img_titles = display_info.get('img_titles', None) + imgs = display_info.get('images', []) + imshow(*imgs, window_title=str(title), figure_title=str(title), + axis_titles=img_titles) def gui_raiseBottomLayoutContextMenu(self, event): try: @@ -8990,14 +9301,13 @@ def searchIDworkerCallback(self, posData, searchedID): for frame_i in range(len(posData.segm_data)): if frame_i >= len(posData.allData_li): break - lab = posData.allData_li[frame_i]['labels'] - if lab is None: - rp = skimage.measure.regionprops(posData.segm_data[frame_i]) - IDs = set([obj.label for obj in rp]) - else: - IDs = posData.allData_li[frame_i]['IDs'] - if searchedID in IDs: + rp = posData.allData_li[frame_i]['regionprops'] + if rp is None: + lab = posData.segm_data[frame_i] + rp = regionprops.acdcRegionprops(lab, precache_centroids=False) + posData.allData_li[frame_i]['regionprops'] = rp + if searchedID in rp.IDs: frame_i_found = frame_i break @@ -9014,8 +9324,7 @@ def warnIDnotFound(self, searchedID): def goToObjectID(self, ID): posData = self.data[self.pos_i] - objIdx = posData.IDs_idxs[ID] - obj = posData.rp[objIdx] + obj = posData.rp.get_obj_from_ID(ID) self.goToZsliceSearchedID(obj) self.highlightSearchedID(ID) @@ -9026,8 +9335,7 @@ def goToLostObjectID(self, lostID, color=(255, 165, 0, 255)): posData = self.data[self.pos_i] frame_i = posData.frame_i prev_rp = posData.allData_li[frame_i-1]['regionprops'] - prev_IDs_idxs = posData.allData_li[frame_i-1]['IDs_idxs'] - obj = prev_rp[prev_IDs_idxs[lostID]] + obj = prev_rp.get_obj_from_ID(lostID) self.goToZsliceSearchedID(obj) imageItem = self.getLostObjImageItem(0) @@ -9038,7 +9346,11 @@ def goToLostObjectID(self, lostID, color=(255, 165, 0, 255)): self.lostObjContoursImage[:] = 0 contours = [] - obj_contours = self.getObjContours(obj, all_external=True) + obj_contours = self.getObjContours( + obj, + all_external=True, + include_internal=self.showAllContoursToggle.isChecked() + ) contours.extend(obj_contours) self.addLostObjsToLostObjImage(obj, lostID) @@ -9050,8 +9362,7 @@ def goToAcceptedLostObjectID(self, acceptedLostID): posData = self.data[self.pos_i] frame_i = posData.frame_i prev_rp = posData.allData_li[frame_i-1]['regionprops'] - prev_IDs_idxs = posData.allData_li[frame_i-1]['IDs_idxs'] - obj = prev_rp[prev_IDs_idxs[acceptedLostID]] + obj = prev_rp.get_obj_from_ID(acceptedLostID) self.goToZsliceSearchedID(obj) self.updateLostTrackedContoursImage(tracked_lost_IDs=[acceptedLostID]) @@ -9633,7 +9944,8 @@ def applyEditID( self, clickedID, currentIDs, oldIDnewIDMapper, clicked_x, clicked_y, shift=False, doPropagateUnvisited=False ): posData = self.data[self.pos_i] - + rp = self.rpCurr2D() if (shift and self.isSegm3D) else posData.rp + use_3D_obj_centroid = not (shift and self.isSegm3D) # Ask to propagate change to all future visited frames key = 'Edit ID' askAction = self.askHowFutureFramesActions[key] @@ -9654,43 +9966,51 @@ def applyEditID( lab = posData.lab # Store undo state before modifying stuff + # no risk of merging IDs if we are working with rp and dont updaet in the middle... self.storeUndoRedoStates(UndoFutFrames) - maxID = max(posData.IDs, default=0) - for old_ID, new_ID in oldIDnewIDMapper: + # could this be chained??? If yes we have to "simplify" to least swops to since we keep RP stale + # oldIDnewIDMapper + assignments = {} + for old_ID, new_ID in oldIDnewIDMapper: if new_ID in currentIDs and not self.editIDmergeIDs: - tempID = maxID + 1 - lab[lab == old_ID] = maxID + 1 - lab[lab == new_ID] = old_ID - lab[lab == tempID] = new_ID - maxID += 1 - - old_ID_idx = currentIDs.index(old_ID) - new_ID_idx = currentIDs.index(new_ID) - - # Append information for replicating the edit in tracking - # List of tuples (y, x, replacing ID) - objo = posData.rp[old_ID_idx] - yo, xo = self.getObjCentroid(objo.centroid) - objn = posData.rp[new_ID_idx] - yn, xn = self.getObjCentroid(objn.centroid) - if not math.isnan(yo) and not math.isnan(yn): - yn, xn = int(yn), int(xn) - posData.editID_info.append((yn, xn, new_ID)) - yo, xo = int(clicked_y), int(clicked_x) - posData.editID_info.append((yo, xo, old_ID)) + objo = rp.get_obj_from_ID(old_ID) + objn = rp.get_obj_from_ID(new_ID) + + # Relabel old_ID to new ID, save since rp is "stale" + slc_o = objo.slice + mask_o = objo.image + lab[slc_o][mask_o] = new_ID + + # Relabel new_ID to old_ID + slc_n = objn.slice + mask_n = objn.image + lab[slc_n][mask_n] = old_ID + + + # ¯\_(ツ)_/¯ + if use_3D_obj_centroid: + objn_centroid = rp.get_centroid(old_ID, exact=True) # + yn, xn = self.getObjCentroid(objn_centroid) + if not math.isnan(yn): + yn, xn = int(yn), int(xn) + posData.editID_info.append((yn, xn, new_ID)) + yo, xo = int(clicked_y), int(clicked_x) + posData.editID_info.append((yo, xo, old_ID)) + assignments[new_ID] = old_ID + assignments[old_ID] = new_ID else: - lab[lab == old_ID] = new_ID - if new_ID > maxID: - maxID = new_ID - old_ID_idx = posData.IDs.index(old_ID) - - # Append information for replicating the edit in tracking - # List of tuples (y, x, replacing ID) - obj = posData.rp[old_ID_idx] - y, x = self.getObjCentroid(obj.centroid) - if not math.isnan(y) and not math.isnan(y): - y, x = int(y), int(x) - posData.editID_info.append((y, x, new_ID)) + # Use regionprops for old_ID + obj = rp.get_obj_from_ID(old_ID) + slc = obj.slice + mask = obj.image + lab[slc][mask] = new_ID + if use_3D_obj_centroid: + centroid = rp.get_centroid(old_ID, exact=True) + y, x = self.getObjCentroid(centroid) + if not math.isnan(y) and not math.isnan(x): + y, x = int(y), int(x) + posData.editID_info.append((y, x, new_ID)) + assignments[old_ID] = new_ID self.updateAssignedObjsAcdcTrackerSecondStep(new_ID) @@ -9698,7 +10018,10 @@ def applyEditID( self.set_2Dlab(lab) # Update rps - self.update_rp() + # When shift is active we edited the 2D slice of a 3D lab; the cached + # slice/image data in the old RP objects is stale so we must do a full + # recompute rather than the fast assignments-only path. + self.update_rp(assignments=None if (shift and self.isSegm3D) else assignments) # Since we manually changed an ID we don't want to repeat tracking self.setAllTextAnnotations() @@ -9763,14 +10086,23 @@ def getHoverID(self, xdata, ydata, byPassShiftCheck=False): ymin, xmin, ymax, xmax, diskMask = self.getDiskMask(xdata, ydata) posData = self.data[self.pos_i] - lab_2D = self.get_2Dlab(posData.lab) + lab_2D = self.get_2Dlab(posData.lab, force_z=False) ID = lab_2D[ydata, xdata] self.isHoverZneighID = False if self.isSegm3D: + zProjHow = self.zProjComboBox.currentText() + isZslice = zProjHow == 'single z-slice' z = self.z_lab() SizeZ = posData.lab.shape[0] doNotLinkThroughZ = self.brushButton.isChecked() and shift - if doNotLinkThroughZ: + if not isZslice: + # In projection mode, ID comes from the projected 2D label image. + if self.brushHoverCenterModeAction.isChecked() or ID>0: + hoverID = ID + else: + masked_lab = lab_2D[ymin:ymax, xmin:xmax][diskMask] + hoverID = np.bincount(masked_lab).argmax() + elif doNotLinkThroughZ: if self.brushHoverCenterModeAction.isChecked() or ID>0: hoverID = ID else: @@ -9805,11 +10137,12 @@ def getHoverID(self, xdata, ydata, byPassShiftCheck=False): else: hoverIDc = 0 - if hoverIDa > 0: + # When clicking directly on an object, prefer current-slice ID. + if hoverIDb > 0: + hoverID = hoverIDb + elif hoverIDa > 0: hoverID = hoverIDa self.isHoverZneighID = True - elif hoverIDb > 0: - hoverID = hoverIDb elif hoverIDc > 0: hoverID = hoverIDc self.isHoverZneighID = True @@ -9840,7 +10173,7 @@ def setHoverToolSymbolColor( shift = modifiers == Qt.ShiftModifier posData = self.data[self.pos_i] - Y, X = self.get_2Dlab(posData.lab).shape + Y, X = self.get_2Dlab(posData.lab, force_z=False).shape if not myutils.is_in_bounds(xdata, ydata, X, Y): return @@ -10158,7 +10491,9 @@ def smoothAutoContWithSpline(self, n=3): xxA, yyA = xx[::n], yy[::n] rr, cc = skimage.draw.polygon(yyA, xxA) self.autoContObjMask[rr, cc] = 1 - rp = skimage.measure.regionprops(self.autoContObjMask) + rp = regionprops.acdcRegionprops( + self.autoContObjMask, precache_centroids=False + ) if not rp: return obj = rp[0] @@ -10258,14 +10593,6 @@ def annotateIsHistoryKnown(self, ID): # If the cell with unknown history has a relative ID assigned to it # we set the cca of it to the status it had BEFORE the assignment posData.cca_df.loc[relID] = relID_cca - - # Update cell cycle info LabelItems - obj_idx = posData.IDs.index(ID) - rp_ID = posData.rp[obj_idx] - - if relID in posData.IDs: - relObj_idx = posData.IDs.index(relID) - rp_relID = posData.rp[relObj_idx] self.setAllTextAnnotations() self.drawAllMothBudLines() @@ -10486,12 +10813,7 @@ def undoBudMothAssignment(self, ID): posData.cca_df.at[relID, 'generation_num'] = 2 posData.cca_df.at[relID, 'cell_cycle_stage'] = 'G1' posData.cca_df.at[relID, 'relationship'] = 'mother' - - obj_idx = posData.IDs.index(ID) - relObj_idx = posData.IDs.index(relID) - rp_ID = posData.rp[obj_idx] - rp_relID = posData.rp[relObj_idx] - + self.store_cca_df() # Update cell cycle info LabelItems @@ -11553,14 +11875,11 @@ def delBorderObj(self, checked): self.storeUndoRedoStates(False) posData = self.data[self.pos_i] - posData.lab = skimage.segmentation.clear_border( - posData.lab, buffer_size=1 + edge_ids = myutils.clear_border(posData.lab, return_edge_ids=True) # modifies inplace + self.update_rp(deletionIDs=edge_ids) + self.update_cca_df_deletedIDs( + posData, edge_ids, dropInPast=False, dropInFuture=False ) - oldIDs = posData.IDs.copy() - self.update_rp() - removedIDs = [ID for ID in oldIDs if ID not in posData.IDs] - if posData.cca_df is not None: - posData.cca_df = posData.cca_df.drop(index=removedIDs) self.store_data() self.updateAllImages() @@ -11574,7 +11893,7 @@ def delNewObj(self, checked): if frame_i == 0: return - prev_IDs = posData.allData_li[frame_i-1]['IDs'] + prev_IDs = posData.allData_li[frame_i-1]['regionprops'].IDs curr_IDs = posData.IDs new_IDs = list(set(curr_IDs) - set(prev_IDs)) @@ -11583,7 +11902,7 @@ def delNewObj(self, checked): lab[del_mask] = 0 posData.lab = lab - self.update_rp() + self.update_rp(deletionIDs=new_IDs) if posData.cca_df is not None: posData.cca_df = posData.cca_df.drop(index=new_IDs) @@ -11602,16 +11921,21 @@ def brushAutoHideToggled(self, checked): def brushReleased(self): posData = self.data[self.pos_i] - self.fillHolesID(posData.brushID, sender='brush') + do_auto_fill = self.brushAutoFillCheckbox.isChecked() + self.fillHolesID(posData.brushID, sender='brush', enabled=do_auto_fill) # Update data (rp, etc) - self.update_rp(update_IDs=self.isNewID,) + + power_brush = self.isPowerBrush() + # we have to delay for a second + self.update_rp( + use_curr_view=True, + specific_IDs=posData.brushID if not power_brush else None + ) # Repeat tracking if self.autoIDcheckbox.isChecked(): self.trackManuallyAddedObject(posData.brushID, self.isNewID) - else: - self.update_rp(update_IDs=posData.brushID not in posData.IDs_idxs) # Update images if self.isNewID: @@ -11783,7 +12107,7 @@ def delROImoving(self, roi): def delROImovingFinished(self, roi: pg.ROI): roi.setPen(color='r') - self.update_rp() + self.update_rp() # get bbox of delROI old and new, run update_rp on both seperately self.updateAllImages() QTimer.singleShot( 300, partial(self.updateDelROIinFutureFrames, roi) @@ -11821,7 +12145,7 @@ def restoreAnnotDelROI(self, roi, enforce=True, draw=True): delROIs_info['delIDsROI'][idx] = delIDs - restoredIDs self.set_2Dlab(lab2D) - self.update_rp() + self.update_rp() # get bbox of delROI old and new, run update_rp on both seperately def restoreDelROIimg1(self, delMaskID, delID, ax=0): if ax == 0: @@ -11833,7 +12157,9 @@ def restoreDelROIimg1(self, delMaskID, delID, ax=0): return if how.find('contours') != -1: - rp_delmask = skimage.measure.regionprops(delMaskID.astype(np.uint8)) + rp_delmask = regionprops.acdcRegionprops( + delMaskID.astype(np.uint8), precache_centroids=False + ) if len(rp_delmask) > 0: obj = rp_delmask[0] self.addObjContourToContoursImage(obj=obj, ax=ax) @@ -11903,7 +12229,9 @@ def getDelROIlab(self, input_lab_2D=None): idx = delROIs_info['rois'].index(roi) delObjROImask = delROIs_info['delMasks'][idx] delIDsROI = delROIs_info['delIDsROI'][idx] - delROIlabRp = skimage.measure.regionprops(out_lab) + delROIlabRp = regionprops.acdcRegionprops( + out_lab, precache_centroids=False + ) for delObj in delROIlabRp: isDelObj = np.any(ROImask[delObj.slice][delObj.image]) if not isDelObj: @@ -12918,7 +13246,6 @@ def changeMode(self, text): self.addExistingDelROIs() self.isFirstTimeOnNextFrame() self.setEnabledCcaToolbar(enabled=False) - self.clearComputedContours() self.realTimeTrackingToggle.setDisabled(False) self.realTimeTrackingToggle.label.setDisabled(False) if posData.cca_df is not None: @@ -12935,9 +13262,7 @@ def changeMode(self, text): self.modeToolBar.setVisible(True) self.realTimeTrackingToggle.setDisabled(True) self.realTimeTrackingToggle.label.setDisabled(True) - self.computeAllContours() # RAWR!!!!! - # self.computeAllObjToObjCostPairs() if proceed: self.setEnabledEditToolbarButton(enabled=False) if self.isSnapshot: @@ -12959,7 +13284,6 @@ def changeMode(self, text): self.navigateScrollBar.setMaximum(posData.SizeT) self.navSpinBox.setMaximum(posData.SizeT) self.clearGhost() - self.computeAllContours() elif mode == 'Custom annotations': self.setAutoSaveAnnotationsEnabled(True) self.setSwitchViewedPlaneDisabled(True) @@ -12972,14 +13296,12 @@ def changeMode(self, text): self.annotateToolbar.setVisible(True) self.clearGhost() self.doCustomAnnotation(0) - self.computeAllContours() elif mode == 'Snapshot': self.setAutoSaveAnnotationsEnabled(True) self.setSwitchViewedPlaneDisabled(False) self.reconnectUndoRedo() self.setEnabledSnapshotMode() self.doCustomAnnotation(0) - self.clearComputedContours() elif mode == 'Normal division: Lineage tree': # Mode activation for lineage tree # self.startLinTreeIntegrityCheckerWorker() # need to replace (postponed) proceed = self.initLinTree() @@ -13280,8 +13602,7 @@ def manualAnnotPast_cb(self, checked): ) self.editIDspinbox.setValue(hoverID) try: - obj_idx = posData.IDs_idxs[hoverID] - obj = posData.rp[obj_idx] + obj = posData.rp.get_obj_from_ID(hoverID) radius = 0.9 * obj.minor_axis_length / 2 # math.sqrt(obj.area/math.pi)*0.9 self.brushSizeSpinbox.setValue(round(radius)) except Exception as err: @@ -13384,16 +13705,15 @@ def _copyAllLostObjects_navigateToFrame(self, frame_i): self.get_data() self.tracking(wl_update=False) self.currentLab2D = self.get_2Dlab(posData.lab) - self.update_rp() + self.update_rp() # cannot be more granular as lost obj could be anywhere self.updateLostNewCurrentIDs() self.store_data(mainThread=False, autosave=False) self.lostObjContoursImage[:] = 0 self.lostObjImage[:] = 0 prev_rp = posData.allData_li[frame_i-1]['regionprops'] - prev_IDs_idxs = posData.allData_li[frame_i-1]['IDs_idxs'] # need to change this when merging with opt. for lostID in posData.lost_IDs: - obj = prev_rp[prev_IDs_idxs[lostID]] + obj = prev_rp.get_obj_from_ID(lostID) self.addLostObjsToLostObjImage(obj, lostID, force=True) def _copyAllLostObjects_returnToFrame(self, frame_i): @@ -13495,7 +13815,7 @@ def copyAllLostObjectsWorkerFinished(self, output): self.blinker.start() self.copyAllLostObjectsWorkerLoop.exit() - self.update_rp() + self.update_rp() # global op and obj added, no opt imo unless difference pic self.updateAllImages() self.store_data() @@ -13638,7 +13958,9 @@ def clearObjsFreehandRegion(self): regionLab = transformation.clear_objects_not_in_mask( regionLab, mask ) - regionRp = skimage.measure.regionprops(regionLab) + regionRp = regionprops.acdcRegionprops( + regionLab, precache_centroids=False + ) for obj in regionRp: if np.all(mask[obj.slice][obj.image]): continue @@ -13652,7 +13974,9 @@ def clearObjsFreehandRegion(self): else: regionLab[..., ~mask] = 0 - regionRp = skimage.measure.regionprops(regionLab) + regionRp = regionprops.acdcRegionprops( + regionLab, precache_centroids=False + ) clearIDs = [obj.label for obj in regionRp] if not clearIDs: @@ -13711,7 +14035,7 @@ def labelRoiDone(self, roiSegmData, isTimeLapse): frame_i = start_frame_i + i lab = posData.allData_li[frame_i]['labels'] store = True - if lab is None: + if lab is None: # no rp update here? if frame_i >= len(posData.segm_data): lab = np.zeros_like(posData.segm_data[0]) posData.segm_data = np.append( @@ -13727,6 +14051,7 @@ def labelRoiDone(self, roiSegmData, isTimeLapse): if store: posData.frame_i = frame_i posData.allData_li[frame_i]['labels'] = lab.copy() + # no rp update here? self.get_data() self.store_data(autosave=False) @@ -13739,7 +14064,7 @@ def labelRoiDone(self, roiSegmData, isTimeLapse): roiLab, self.labelRoiSlice, posData.lab, posData.brushID ) - self.update_rp() + self.update_rp() # get roi and set as bbox # Repeat tracking if self.autoIDcheckbox.isChecked(): @@ -13770,8 +14095,7 @@ def labelRoiDone(self, roiSegmData, isTimeLapse): def restoreHoverObjBrush(self): posData = self.data[self.pos_i] if self.ax1BrushHoverID in posData.IDs: - obj_idx = posData.IDs_idxs[self.ax1BrushHoverID] - obj = posData.rp[obj_idx] + obj = posData.rp.get_obj_from_ID(self.ax1BrushHoverID) if not self.isObjVisible(obj.bbox): return @@ -13863,15 +14187,20 @@ def setAllIDs(self, onlyVisited=False): for frame_i in range(len(posData.segm_data)): if frame_i >= len(posData.allData_li): break + lab = posData.allData_li[frame_i]['labels'] if lab is None and onlyVisited: break - if lab is None: - rp = skimage.measure.regionprops(posData.segm_data[frame_i]) - else: - rp = posData.allData_li[frame_i]['regionprops'] - posData.allIDs.update([obj.label for obj in rp]) + rp = posData.allData_li[frame_i]['regionprops'] + if rp is None: + lab = posData.segm_data[frame_i] + rp = regionprops.acdcRegionprops( + lab, precache_centroids=False + ) + posData.allData_li[frame_i]['regionprops'] = rp + + posData.allIDs.update(rp.IDs) def countObjectsTimelapse(self): if self.countObjsWindow is None: @@ -13920,11 +14249,13 @@ def countObjectsSnapshots(self): numObjectsCurrentZslice = None if 'In current z-slice' in activeCategories: numObjectsCurrentZslice = len( - skimage.measure.regionprops(self.currentLab2D) + regionprops.acdcRegionprops( + self.currentLab2D, precache_centroids=False + ) ) for pos_i, _posData in enumerate(self.data): - IDs = _posData.allData_li[0]['IDs'] + IDs = _posData.allData_li[0]['regionprops'].IDs if os.path.exists(_posData.acdc_output_csv_path): numObjectsVisitedPosPrevious += len(IDs) if IDs: @@ -13932,7 +14263,9 @@ def countObjectsSnapshots(self): numObjectsAllPos += len(IDs) else: lab = _posData.segm_data[0] - rp = skimage.measure.regionprops(lab) + rp = regionprops.acdcRegionprops( + lab, precache_centroids=False + ) numObjs = len(rp) numObjectsAllPos += numObjs @@ -14814,10 +15147,10 @@ def keyPressEvent(self, ev): if ev.key() == Qt.Key_Q and self.debug: try: from . import _q_debug - _q_debug.q_debug(self) + _q_debug.q_debug(self, ev) except Exception as err: printl(traceback.format_exc()) - printl('[ERROR]: Error with "_qdebug" module. See Traceback above.') + printl('[ERROR]: Error with "_q_debug" module. See Traceback above.') pass if not self.isDataLoaded: @@ -15159,7 +15492,7 @@ def propagateMergeObjsPast(self, IDs_to_merge): posData.frame_i = past_frame_i self.get_data() - IDs = posData.allData_li[past_frame_i]['IDs'] + IDs = posData.allData_li[past_frame_i]['regionprops'].IDs stop_loop = False for ID in IDs_to_merge: if ID not in IDs: @@ -15168,10 +15501,13 @@ def propagateMergeObjsPast(self, IDs_to_merge): if ID == 0: continue - posData.lab[posData.lab==ID] = self.firstID - self.update_rp() - - self.store_data(autosave=False) + obj = posData.rp.get_obj_from_ID(ID) + posData.lab[obj.slice][obj.image] = self.firstID + + preloaded_bbox = self.update_rp_get_bbox(specific_IDs=IDs_to_merge,use_bbox=True) # use old RP to get the correct bbox + specific_IDs = [*IDs_to_merge, self.firstID] + self.update_rp(preloaded_bbox=preloaded_bbox, specific_IDs=specific_IDs) + self.store_data(autosave=False) if stop_loop: break @@ -15210,11 +15546,10 @@ def propagateChange( # Stop at last visited frame since includeUnvisited = False break else: - lab = posData.segm_data[i] + IDs = posData.allData_li[i]['regionprops'].IDs else: - lab = posData.allData_li[i]['labels'] - - if modID in lab: + IDs = posData.allData_li[i]['regionprops'].IDs + if modID in IDs: areFutureIDs_affected.append(True) if not last_tracked_i_found: @@ -15694,13 +16029,19 @@ def warnTrackerInputNotValid(self, trackerName, warningText): def repeatTracking(self): posData = self.data[self.pos_i] - prev_lab = self.get_2Dlab(posData.lab).copy() - self.tracking(enforce=True, DoManualEdit=False) + tracked_lab, assignments = self.tracking( + enforce=True, + DoManualEdit=False, + return_assignments=True, + return_lab=True + ) + posData.lab = tracked_lab if posData.editID_info: + lab2D = self.get_2Dlab(posData.lab) editedIDsInfo = { - posData.lab[y,x]:newID + lab2D[y,x]:newID for y, x, newID in posData.editID_info - if posData.lab[y,x] != newID + if lab2D[y,x] != newID } editedIDsInfoItems = [ f'ID {oldID} --> {newID}' @@ -15726,18 +16067,26 @@ def repeatTracking(self): detailsText=editIDul ) if msg.cancel: + self.update_rp(assignments=assignments) # rp now stale as we return img return if msg.clickedButton == keepManualEditButton: - allIDs = [obj.label for obj in posData.rp] + allIDs = posData.rp.IDs lab2D = self.get_2Dlab(posData.lab) - self.manuallyEditTracking(lab2D, allIDs) - self.update_rp() + tracked_lab, assignments = self.manuallyEditTracking( + lab2D, assignments) # here not use tracked lab? + self.update_rp(assignments=assignments) # rp now stale as we return img self.setAllTextAnnotations() self.highlightLostNew() # self.checkIDsMultiContour() else: + self.update_rp(assignments=assignments) # rp now stale as we return img posData.editID_info = [] - if np.any(posData.lab != prev_lab): + else: + self.update_rp(assignments=assignments) + + # filter self assignments + assignments = {k: v for k, v in assignments.items() if k != v} + if assignments: if self.isSnapshot: self.fixCcaDfAfterEdit('Repeat tracking') self.updateAllImages() @@ -15809,8 +16158,7 @@ def initManualBackgroundObject(self, ID=None): self.manualBackgroundObjItem.clear() return - ID_idx = posData.IDs_idxs[ID] - self.manualBackgroundObj = posData.rp[ID_idx] + self.manualBackgroundObj = posData.rp.get_obj_from_ID(ID) self.manualBackgroundToolbar.clearInfoText() self.manualBackgroundObj.contour = self.getObjContours( @@ -16908,15 +17256,14 @@ def doCustomAnnotation(self, ID): xx, yy = [], [] for annotID in annotIDs_frame_i: - if annotID not in posData.IDs_idxs: + if annotID not in posData.rp.IDs: continue - - obj_idx = posData.IDs_idxs[annotID] - obj = posData.rp[obj_idx] + obj = posData.rp.get_obj_from_ID(annotID) acdc_df.at[annotID, state['name']] = 1 if not self.isObjVisible(obj.bbox): continue - y, x = self.getObjCentroid(obj.centroid) + y, x = self.getObjCentroid( + posData.rp.get_centroid(annotID, exact=True)) xx.append(x) yy.append(y) @@ -17564,11 +17911,13 @@ def segmWorkerFinished(self, lab, exec_time): self.update_rp(wl_update=False) self.tracking(enforce=True, against_next=posData.frame_i==0) - if self.isSnapshot: - self.fixCcaDfAfterEdit('Repeat segmentation') - self.updateAllImages() - else: - self.warnEditingWithCca_df('Repeat segmentation') + proceed = self.checkHandleTooManyNewItems() + if proceed: + if self.isSnapshot: + self.fixCcaDfAfterEdit('Repeat segmentation') + self.updateAllImages() + else: + self.warnEditingWithCca_df('Repeat segmentation') txt = f'Done. Segmentation computed in {exec_time:.3f} s' self.logger.info('-----------------') @@ -17788,7 +18137,7 @@ def zoomToCells(self, enforce=False): posData = self.data[self.pos_i] lab_mask = (self.currentLab2D>0).astype(np.uint8) - rp = skimage.measure.regionprops(lab_mask) + rp = regionprops.acdcRegionprops(lab_mask, precache_centroids=False) if not rp: Y, X = lab_mask.shape xRange = -0.5, X+0.5 @@ -18203,9 +18552,8 @@ def warnLostObjects(self, do_warn=True): posData.accepted_lost_IDs[frame_i].extend(posData.lost_IDs) # This section is adding the lost cells to tracked_lost_centroids... TBH I dont know why this wasnt done in the first place prev_rp = posData.allData_li[posData.frame_i-1]['regionprops'] - prev_IDs_idxs = posData.allData_li[posData.frame_i-1]['IDs_idxs'] accepted_lost_centroids = { - tuple(int(val) for val in prev_rp[prev_IDs_idxs[ID]].centroid) + tuple(int(val) for val in prev_rp.get_centroid(ID, exact=True)) for ID in posData.lost_IDs } try: @@ -18285,8 +18633,7 @@ def checkIfFutureFrameManualAnnotPastFrames(self): self.statusBarLabel.setText(f'

{warn_txt}

') return False - - # @exec_time + def next_frame(self, warn=True): proceed = self.checkIfFutureFrameManualAnnotPastFrames() if not proceed: @@ -18350,7 +18697,6 @@ def next_frame(self, warn=True): ) return - self.store_zslices_rp() self.resetExpandLabel() self.updateAllImages() self.updateHighlightedAxis() @@ -18815,6 +19161,7 @@ def loadSelectedData(self, user_ch_file_paths, user_ch_name): create_new_segm=self.isNewFile, new_endname=self.newSegmEndName, end_filename_segm=selectedSegmEndName, + load_segm_info_ini=True ) self.selectedSegmEndName = selectedSegmEndName self.labelBoolSegm = posData.labelBoolSegm @@ -20240,7 +20587,7 @@ def curvToolSplineToObj(self, xxA=None, yyA=None, isRightClick=False): xxS, yyS = self.curvPlotItem.getData() if xxS is None: self.setUncheckedAllButtons() - return + return None, None self.smoothAutoContWithSpline() xxS, yyS = self.getClosedSplineCoords() @@ -20264,6 +20611,7 @@ def curvToolSplineToObj(self, xxA=None, yyA=None, isRightClick=False): lab2D[newIDMask] = curvToolID self.set_2Dlab(lab2D) self.currentLab2D = lab2D + return newIDMask, curvToolID def addFluoChNameContextMenuAction(self, ch_name): posData = self.data[self.pos_i] @@ -20525,6 +20873,10 @@ def framesScrollBarMoved(self, frame_n): posData.lab = posData.segm_data[posData.frame_i] else: posData.lab = np.zeros_like(posData.segm_data[0]) + rp = regionprops.acdcRegionprops(posData.lab, precache_centroids=False) + posData.rp = rp + posData.IDs = [] + posData.allData_li[posData.frame_i]['regionprops'] = rp else: posData.lab = posData.allData_li[posData.frame_i]['labels'] @@ -20547,7 +20899,10 @@ def framesScrollBarMoved(self, frame_n): def framesScrollBarReleased(self, do_store_data=False): posData = self.data[self.pos_i] - if posData.frame_i == self.navigateScrollBar.sliderPosition()-1: + if ( + posData.frame_i == self.navigateScrollBar.sliderPosition()-1 + and self.navigateScrollBarStartedMoving + ): # Slider released without changing value --> do nothing return @@ -20575,37 +20930,67 @@ def getStoredSegmData(self): segm_data.append(lab) return np.array(segm_data) - def trackNewIDtoNewIDsFutureFrame(self, newID, newIDmask): + def trackNewIDtoNewIDsFutureFrame(self, newID, obj, assignments): + # here RP is stale posData = self.data[self.pos_i] try: nextLab = posData.allData_li[posData.frame_i+1]['labels'] except IndexError: # This is last frame --> there are no future frames - return + return None, assignments if nextLab is None: - return + return None, assignments + + if obj is None: + return None, assignments + - newID_lab = np.zeros_like(posData.lab) - newID_lab[newIDmask] = newID - newLab_rp = [posData.rp[posData.IDs_idxs[newID]]] - newLab_IDs = [newID] nextRp = posData.allData_li[posData.frame_i+1]['regionprops'] + nextLab = posData.allData_li[posData.frame_i+1]['labels'] + reverse_assignments = {v:k for k, v in assignments.items()} - tracked_lab = self.trackFrame( - nextLab, nextRp, newID_lab, newLab_rp, newLab_IDs, - assign_unique_new_IDs=False + rp = posData.rp + lab = posData.lab + + # make rp remporarliy not stale anymore + rp.update_regionprops_via_assignments(assignments, lab) + tracked_lab, assignments_new = self.trackFrame( + nextLab, nextRp, lab, rp, rp.IDs, + assign_unique_new_IDs=False, return_assignments=True, + specific_IDs=[newID], ) - trackedID = tracked_lab[newID_lab>0][0] + # restore rp + posData.rp.update_regionprops_via_assignments(reverse_assignments, lab) + + # clear self assignments + assignments_new = { + k:v for k, v in assignments_new.items() if k != v + } + if not assignments_new: + return None, assignments + + trackedIDs = list(assignments_new.values()) + + trackedID = trackedIDs[0] if trackedID == newID: # Object does not exist in future frame --> do not track - return + return None, assignments - if posData.IDs_idxs.get(trackedID) is not None: + if posData.rp.get_obj_from_ID(trackedID, warn=False) is not None: # Tracked ID already exists --> do not track to avoid merging - return + return None, assignments - return trackedID + + + # update assignments + assignments = { + old_ID: tracked_ID for old_ID, tracked_ID in assignments.items() + if old_ID != newID + } + assignments[newID] = trackedID + + return trackedID, assignments def store_manual_annot_data( self, posData=None, data_frame_i=None @@ -20648,21 +21033,19 @@ def store_data( # self.lin_tree_ask_changes() allData_li = posData.allData_li[posData.frame_i] - allData_li['regionprops'] = posData.rp.copy() + + + allData_li['regionprops'] = posData.rp allData_li['labels'] = posData.lab.copy() - allData_li['IDs'] = posData.IDs.copy() + allData_li['regionprops'].IDs = posData.IDs allData_li['manualBackgroundLab'] = ( posData.manualBackgroundLab ) - allData_li['IDs_idxs'] = ( - posData.IDs_idxs.copy() - ) if self.manualAnnotPastButton.isChecked(): self.store_manual_annot_data( - posData=posData, data_frame_i=allData_li + posData=posData, data_frame_i=allData_li ) - self.store_zslices_rp() # Store dynamic metadata is_cell_dead_li = [False]*len(posData.rp) @@ -20678,13 +21061,17 @@ def store_data( is_cell_dead_li[i] = obj.dead is_cell_excluded_li[i] = obj.excluded IDs[i] = obj.label - try: - xx_centroid[i] = int(self.getObjCentroid(obj.centroid)[1]) - yy_centroid[i] = int(self.getObjCentroid(obj.centroid)[0]) - except Exception as err: - printl(obj, obj.centroid, obj.label, posData.frame_i) + centroid = posData.rp.get_centroid(obj.label, exact=True) + if centroid is None: + continue + if self.isSegm3D: - zz_centroid[i] = int(obj.centroid[0]) + zz_centroid[i] = int(centroid[0]) + xx_centroid[i] = int(centroid[2]) + yy_centroid[i] = int(centroid[1]) + else: + xx_centroid[i] = int(centroid[1]) + yy_centroid[i] = int(centroid[0]) if obj.label in editedNewIDs: areManuallyEdited[i] = 1 @@ -21366,7 +21753,18 @@ def get_2Dlab(self, lab, force_z=True): if isZslice: return lab[self.z_lab()] else: - return lab.max(axis=0) + if self.switchPlaneCombobox.isEnabled(): + slicing = self.switchPlaneCombobox.depthAxes() + else: + slicing = 'z' + + posData = self.data[self.pos_i] + rp = getattr(posData, 'rp', None) + if rp is not None and rp.is3D: + return rp.get_projection_lab_sorted(slicing=slicing) + + rp = regionprops.acdcRegionprops(lab, precache_centroids=False) + return rp.get_projection_lab_sorted(slicing=slicing) else: return lab @@ -21455,24 +21853,13 @@ def applyBrushMask(self, mask, ID, toLocalSlice=None): posData.lab[mask] = ID def assignNewIDfromClickedID( - self, clickedID: int, event: QGraphicsSceneMouseEvent + self, clickedID: int, event: QGraphicsSceneMouseEvent, shift: bool = False ): posData = self.data[self.pos_i] x, y = event.pos().x(), event.pos().y() newID = self.setBrushID(return_val=True) mapper = [(clickedID, newID)] - self.applyEditID(clickedID, posData.IDs.copy(), mapper, x, y) - - def get_2Drp(self, lab=None): - if self.isSegm3D: - if lab is None: - # self.currentLab2D is defined at self.setImageImg2() - lab = self.currentLab2D - lab = self.get_2Dlab(lab) - rp = skimage.measure.regionprops(lab) - return rp - else: - return self.data[self.pos_i].rp + self.applyEditID(clickedID, posData.IDs.copy(), mapper, x, y, shift=shift) def set_2Dlab(self, lab2D, lab3D=None): posData = self.data[self.pos_i] @@ -21555,6 +21942,11 @@ def get_labels( else: shape = (posData.SizeY, posData.SizeX) labels = np.zeros(shape, dtype=np.uint32) + rp = regionprops.acdcRegionprops(labels, precache_centroids=False) + if frame_i == posData.frame_i: + posData.rp = rp + posData.IDs = [] + posData.allData_li[frame_i]['regionprops'] = rp return_copy = False if return_copy: @@ -21568,10 +21960,12 @@ def get_labels( def addYXcentroidToDf(self, df): posData = self.data[self.pos_i] for obj in posData.rp: - y_centroid = int(self.getObjCentroid(obj.centroid)[0]) - x_centroid = int(self.getObjCentroid(obj.centroid)[1]) - df.at[obj.label, 'y_centroid'] = y_centroid - df.at[obj.label, 'x_centroid'] = x_centroid + ID = obj.label + centroid = posData.rp.get_centroid(obj, exact=True) + y_centroid = int(self.getObjCentroid(centroid)[0]) + x_centroid = int(self.getObjCentroid(centroid)[1]) + df.at[ID, 'y_centroid'] = y_centroid + df.at[ID, 'x_centroid'] = x_centroid return df def _get_editID_info(self, df): @@ -21651,7 +22045,10 @@ def _get_data_unvisited(self, posData, debug=False, lin_tree_init=True,): posData.lab = self.apply_manual_edits_to_lab_if_needed( labels ) - posData.rp = skimage.measure.regionprops(posData.lab) + posData.rp = posData.allData_li[posData.frame_i]['regionprops'] + if posData.rp is None: + posData.rp = regionprops.acdcRegionprops(labels, precache_centroids=False) + # get stored IDs self.setManualBackgroundLab() if posData.acdc_df is not None: @@ -21694,7 +22091,8 @@ def _get_data_visited(self, posData, debug=False, lin_tree_init=True,): # Requested frame was already visited. Load from RAM. never_visited = False posData.lab = self.get_labels(from_store=True) - posData.rp = skimage.measure.regionprops(posData.lab) + # posData.rp = regionprops.acdcRegionprops(posData.lab, precache_centroids=False) + posData.rp = posData.allData_li[posData.frame_i]['regionprops'] df = posData.allData_li[posData.frame_i]['acdc_df'] if df is None: posData.binnedIDs = set() @@ -21734,6 +22132,7 @@ def get_data(self, debug=False, lin_tree_init=True): else: self.undoAction.setDisabled(True) self.UndoCount = 0 + # If stored labels is None then it is the first time we visit this frame if posData.allData_li[posData.frame_i]['labels'] is None: proceed_cca, never_visited = self._get_data_unvisited( @@ -21746,12 +22145,12 @@ def get_data(self, debug=False, lin_tree_init=True): posData, lin_tree_init=lin_tree_init, debug=debug ) + if posData.rp is None: # + rp = regionprops.acdcRegionprops(posData.lab, precache_centroids=False) + posData.rp = rp + posData.allData_li[posData.frame_i]['regionprops'] = rp self.update_rp_metadata(draw=False) - posData.IDs = [obj.label for obj in posData.rp] - posData.IDs_idxs = { - ID:i for ID, i in zip(posData.IDs, range(len(posData.IDs))) - } - self.get_zslices_rp() + posData.IDs = posData.rp.IDs self.pointsLayerDfsToData(posData) return proceed_cca, never_visited @@ -21777,7 +22176,7 @@ def addIDBaseCca_df(self, posData, ID): def getBaseCca_df(self, with_tree_cols=False): posData = self.data[self.pos_i] - IDs = [obj.label for obj in posData.rp] + IDs = posData.rp.IDs cca_df = core.getBaseCca_df(IDs, with_tree_cols=with_tree_cols) return cca_df @@ -22339,6 +22738,32 @@ def get_cca_df(self, frame_i=None, return_df=False, debug=False): return cca_df else: posData.cca_df = cca_df + + def _changeIDhelper(self, lab, oldID, newID, rp, assignments): + did_find_newID = False + if newID in rp.IDs: # should here also self.editIDmergeIDs? + # Relabel old_ID to tempID, safe as RP is safe so no merging + objo = rp.get_obj_from_ID(oldID, warn=False) + if objo is not None: + slc_o = objo.slice + mask_o = objo.image + lab[slc_o][mask_o] = newID + assignments[oldID] = newID + # Relabel new_ID to old_ID + objn = rp.get_obj_from_ID(newID) # here warn, we check in the if if it should be there + objn_slice = objn.slice + objn_mask = objn.image + lab[objn_slice][objn_mask] = oldID + assignments[newID] = oldID + did_find_newID = True + else: + obj = rp.get_obj_from_ID(oldID, warn=False) + if obj is not None: + slc = obj.slice + mask = obj.image + lab[slc][mask] = newID + assignments[oldID] = newID + return did_find_newID def changeIDfutureFrames( self, endFrame_i, oldIDnewIDMapper, includeUnvisited, @@ -22355,6 +22780,7 @@ def changeIDfutureFrames( segmSizeT = len(posData.segm_data) for i in range(posData.frame_i+1, segmSizeT): + assignments = {} lab = posData.allData_li[i]['labels'] if lab is None and not includeUnvisited: self.enqAutosave() @@ -22366,49 +22792,47 @@ def changeIDfutureFrames( self.get_data(lin_tree_init=False) if shift and self.isSegm3D: lab = self.get_2Dlab(posData.lab) + rp = self.rpCurr2D() else: lab = posData.lab - + rp = posData.rp + if self.onlyTracking: self.tracking(enforce=True) elif not posData.IDs: continue else: - maxID = max(posData.IDs, default=0) + 1 for old_ID, new_ID in oldIDnewIDMapper: - if new_ID in lab: - tempID = maxID + 1 # lab.max() + 1 - lab[lab == old_ID] = tempID - lab[lab == new_ID] = old_ID - lab[lab == tempID] = new_ID - maxID += 1 - else: - lab[lab == old_ID] = new_ID - + self._changeIDhelper( + lab, old_ID, new_ID, rp, assignments) + if shift and self.isSegm3D: self.set_2Dlab(lab) - - self.update_rp(draw=False) + + self.update_rp( + draw=False, + assignments=assignments if not (shift and self.isSegm3D) else None) self.store_data(autosave=i==endFrame_i) elif includeUnvisited: # Unvisited frame (includeUnvisited = True) lab = posData.segm_data[i] if shift and self.isSegm3D: lab = self.get_2Dlab(lab) + rp = self.rpCurr2D(frame_i=i) else: lab = lab - + rp = posData.allData_li[i]['regionprops'] + + assignments = {} for old_ID, new_ID in oldIDnewIDMapper: - if new_ID in lab: - tempID = lab.max() + 1 - lab[lab == old_ID] = tempID - lab[lab == new_ID] = old_ID - lab[lab == tempID] = new_ID - else: - lab[lab == old_ID] = new_ID - + self._changeIDhelper( + lab, old_ID, new_ID, rp, assignments) + if shift and self.isSegm3D: posData.segm_data[i][self.z_lab()] = lab + rp.update_regionprops(lab) + else: + rp.update_regionprops_via_assignments(assignments, lab) # Back to current frame posData.frame_i = self.current_frame_i @@ -22682,14 +23106,11 @@ def drawObjMothBudLines(self, obj, posData, ax=0): scatterItem = self.getMothBudLineScatterItem(ax, isNew) relative_ID = cca_df_ID['relative_ID'] - try: - relative_rp_idx = posData.IDs_idxs[relative_ID] - except KeyError: - return - - relative_ID_obj = posData.rp[relative_rp_idx] - y1, x1 = self.getObjCentroid(obj.centroid) - y2, x2 = self.getObjCentroid(relative_ID_obj.centroid) + relative_ID_obj = posData.rp.get_obj_from_ID(relative_ID) + obj_centroid = posData.rp.get_centroid(ID) + rel_obj_centroid = posData.rp.get_centroid(relative_ID) + y1, x1 = self.getObjCentroid(obj_centroid) + y2, x2 = self.getObjCentroid(rel_obj_centroid) xx, yy = core.get_line(y1, x1, y2, x2, dashed=True) scatterItem.addPoints(xx, yy) @@ -22731,14 +23152,14 @@ def drawAllLineageTreeLines(self): continue for ID in new_cells: - curr_obj = myutils.get_obj_by_label(rp, ID) + curr_obj = rp.get_obj_from_ID(ID) lin_tree_df_ID = lin_tree_df.loc[ID] # lin_tree_df_mother_ID = lin_tree_df_prev.loc[lin_tree_df_ID["parent_ID_tree"]] if lin_tree_df_ID["parent_ID_tree"] == -1: # make sure that new obj where the parents are not known get skipped continue - mother_obj = myutils.get_obj_by_label(prev_rp, lin_tree_df_ID["parent_ID_tree"]) + mother_obj = prev_rp.get_obj_from_ID(lin_tree_df_ID["parent_ID_tree"]) emerg_frame_i = lin_tree_df_ID["emerg_frame_i"] isNew = emerg_frame_i == frame_i @@ -22774,9 +23195,15 @@ def drawObjLin_TreeMothBudLines(self, ax, obj, mother_obj, isNew, ID=None): return scatterItem = self.getMothBudLineScatterItem(ax, isNew) - - y1, x1 = self.getObjCentroid(obj.centroid) - y2, x2 = self.getObjCentroid(mother_obj.centroid) + + posData = self.data[self.pos_i] + prev_rp = posData.allData_li[posData.frame_i-1]['regionprops'] + rp = posData.rp + if ID is None: + ID = obj.label + ID_mother = mother_obj.label + y1, x1 = self.getObjCentroid(rp.get_centroid(ID)) + y2, x2 = self.getObjCentroid(prev_rp.get_centroid(ID_mother)) xx, yy = core.get_line(y1, x1, y2, x2, dashed=True) scatterItem.addPoints(xx, yy) @@ -22809,69 +23236,16 @@ def getObjOptsSegmLabels(self, obj): objOpts = self.getObjTextAnnotOpts(obj, 'Draw only IDs', ax=1) return objOpts - - def store_zslices_rp(self, force_update=False): - if not self.isSegm3D: - return - - posData = self.data[self.pos_i] - are_zslices_rp_stored = ( - posData.allData_li[posData.frame_i].get('z_slices_rp') is not None - ) - if force_update or not are_zslices_rp_stored: - self._update_zslices_rp() - - posData.allData_li[posData.frame_i]['z_slices_rp'] = posData.zSlicesRp def removeObjectFromRp(self, delID): posData = self.data[self.pos_i] - rp = [] - IDs = [] - IDs_idxs = {} - idx = 0 - for obj in posData.rp: - if obj.label == delID: - continue - rp.append(obj) - IDs.append(obj.label) - IDs_idxs[obj.label] = idx - idx += 1 - - posData.rp = rp - posData.IDs = IDs - posData.IDs_idxs = IDs_idxs - - if not self.isSegm3D: - return - - zSlicesRp = {} - for z, zSliceRp in posData.zSlicesRp.items(): - if delID in zSliceRp: - continue + if not isinstance(delID, (list, set, tuple)): + delIDs = [delID] + else: + delIDs = list(delID) - zSlicesRp[z] = zSlicesRp - - posData.zSlicesRp = zSlicesRp - self.store_zslices_rp(force_update=True) - - def get_zslices_rp(self): - if not self.isSegm3D: - return - - posData = self.data[self.pos_i] - self.store_zslices_rp() - posData.zSlicesRp = posData.allData_li[posData.frame_i]['z_slices_rp'] - - # @exec_time - def _update_zslices_rp(self): - if not self.isSegm3D: - return - - posData = self.data[self.pos_i] - posData.zSlicesRp = {} - for z, lab2d in enumerate(posData.lab): - lab2d_rp = skimage.measure.regionprops(lab2d) - posData.zSlicesRp[z] = {obj.label:obj for obj in lab2d_rp} + posData.rp.update_regionprops_via_deletions(set(delIDs)) + posData.IDs = posData.rp.IDs def instructHowDeleteID(self): if 'showInfoDeleteObject' not in self.df_settings.index: @@ -22914,7 +23288,7 @@ def checkWarnDeletedIDwithEraser(self): for ID in self.erasedIDs: if ID == 0: continue - if ID in posData.IDs_idxs: + if posData.rp.get_obj_from_ID(ID, warn=False) is not None: continue self.instructHowDeleteID() @@ -22928,36 +23302,228 @@ def checkWarnDeletedIDwithEraser(self): return True return False + + def _get_entire_depth_axis_from_2D_cutout(self, cutout): + # cutout = (xl, xr), (yt, yb), z is always on the y position if depth axis is changed + # cutout is in the current view; return grouped ranges in the order + # expected by update_rp_get_bbox before conversion to NumPy bbox order. + posData = self.data[self.pos_i] + if self.isSegm3D: + depthAxes = self.switchPlaneCombobox.depthAxes() + if depthAxes == 'z': + # cutout is (x, y) and we prepend the full z range. + z_max = posData.SizeZ + return ((0, z_max), cutout[0], cutout[1]) + if depthAxes == 'y': + # cutout is (x, z); convert to (z, x, y). + y_max = posData.SizeY + return (cutout[1], cutout[0], (0, y_max)) + elif depthAxes == 'x': + # cutout is (y, z); convert to (z, x, y). + x_max = posData.SizeX + return (cutout[1], (0, x_max), cutout[0]) + else: + return cutout + + def _cutout_to_bbox(self, cutout): + """ + Convert grouped view ranges into a flat bbox in NumPy array order. + 2D input: ((x_min, x_max), (y_min, y_max)) → (y_min, x_min, y_max, x_max) + 3D input: ((z_min, z_max), (y_min, y_max), (x_min, x_max)) → (z_min, y_min, x_min, z_max, y_max, x_max) + """ + cutout = tuple( + (min(r[0], r[1]), max(r[0], r[1])) for r in cutout + ) + if self.isSegm3D: + (z_min, z_max), (y_min, y_max), (x_min, x_max) = cutout + return (z_min, y_min, x_min, z_max, y_max, x_max) + else: + (x_min, x_max), (y_min, y_max) = cutout + return (y_min, x_min, y_max, x_max) + + def _get_perc_cutout_from_total_img(self, cutout): + posData = self.data[self.pos_i] + single_timepoint_segm_size = posData.getSingleTimepointSegmSize() + if self.isSegm3D: + size = (cutout[0][1] - cutout[0][0]) * (cutout[1][1] - cutout[1][0]) * (cutout[2][1] - cutout[2][0]) + else: + size = (cutout[0][1] - cutout[0][0]) * (cutout[1][1] - cutout[1][0]) + return size / single_timepoint_segm_size + + def update_rp_get_bbox(self, custom_bbox=None, use_bbox=False, use_curr_view=False, + specific_IDs=None, add_frac_custom_bbox=0.05): + """ + Returns an expanded bounding box (bbox) for the given IDs or custom_bbox. + Returns False if not enough cells or cutout is too large. + """ + posData = self.data[self.pos_i] + if len(posData.rp.IDs) < RP_OPT_NUM_CELLS_MIN: + return False + if not isinstance(specific_IDs, (list, set, np.ndarray)) and specific_IDs is not None: + specific_IDs = [specific_IDs] + elif specific_IDs is not None and len(specific_IDs) == 0: + specific_IDs = None + + # Helper to merge bboxes + def merge_bbox(b1, b2): + if len(b1) == 4: + return ( + min(b1[0], b2[0]), min(b1[1], b2[1]), + max(b1[2], b2[2]), max(b1[3], b2[3]) + ) + else: + return ( + min(b1[0], b2[0]), min(b1[1], b2[1]), min(b1[2], b2[2]), + max(b1[3], b2[3]), max(b1[4], b2[4]), max(b1[5], b2[5]) + ) + + bbox = None + if custom_bbox or use_bbox: + if not custom_bbox and use_bbox and specific_IDs is not None: + rp_old = posData.rp + for ID in specific_IDs: + b = rp_old.get_obj_from_ID(ID).bbox + bbox = b if bbox is None else merge_bbox(bbox, b) + else: + bbox = custom_bbox + + if bbox is None: + return False + + elif use_curr_view: + cutout = self.ax1ViewRange(integers=True) + cutout = self._get_entire_depth_axis_from_2D_cutout(cutout) + if len(cutout)==2: + (xl, xr), (yt, yb) = cutout + else: + (z1, z2), (xl, xr), (yt, yb) = cutout + z_min = min(z1, z2) + z_max = max(z1, z2) + x_min = min(xl, xr) + x_max = max(xl, xr) + y_min = min(yt, yb) + y_max = max(yt, yb) + bbox = (y_min, x_min, y_max, x_max) if len(cutout)==2 else (z_min, y_min, x_min, z_max, y_max, x_max) + # Expand bbox by a fraction + else: + raise ValueError('''Either custom_bbox or use_bbox or use_curr_view must be provided as True''') + + if len(bbox) == 4: + y_min, x_min, y_max, x_max = bbox + offset_y = int((y_max - y_min) * add_frac_custom_bbox) + offset_x = int((x_max - x_min) * add_frac_custom_bbox) + offset_y = 1 if offset_y == 0 else offset_y + offset_x = 1 if offset_x == 0 else offset_x + size_y, size_x = posData.SizeY, posData.SizeX + cutout = ( + (max(0, x_min - offset_x), min(size_x, x_max + offset_x)), + (max(0, y_min - offset_y), min(size_y, y_max + offset_y)) + ) + else: + z_min, y_min, x_min, z_max, y_max, x_max = bbox + offset_z = int((z_max - z_min) * add_frac_custom_bbox) + offset_y = int((y_max - y_min) * add_frac_custom_bbox) + offset_x = int((x_max - x_min) * add_frac_custom_bbox) + offset_z = 1 if offset_z == 0 else offset_z + offset_y = 1 if offset_y == 0 else offset_y + offset_x = 1 if offset_x == 0 else offset_x + size_z, size_y, size_x = posData.SizeZ, posData.SizeY, posData.SizeX + cutout = ( + (max(0, z_min - offset_z), min(size_z, z_max + offset_z)), + (max(0, y_min - offset_y), min(size_y, y_max + offset_y)), + (max(0, x_min - offset_x), min(size_x, x_max + offset_x)) + ) + + perc_from_global = self._get_perc_cutout_from_total_img(cutout) + if perc_from_global > RP_OPT_PERC_CUTOUT_MAX: + return False + return self._cutout_to_bbox(cutout) @exception_handler def update_rp( - self, draw=True, debug=False, update_IDs=True, - wl_update=True, wl_track_og_curr=False,wl_update_lab=False + self, draw=True, debug=False, # og stuff + assignments=None, deletionIDs=None, # very quick upates, rp labels are changed but rest is same + specific_IDs=None, use_curr_view=False, use_bbox=False, preloaded_bbox=None, # for local updates to PR + wl_update=True, wl_track_og_curr=False,wl_update_lab=False, # wl stuff ): + """Updates posData.rp + Parameters + ---------- + + """ + #updating rp is very clostly, as it deletes all the cashed + if use_curr_view and use_bbox: + raise ValueError('''use_curr_view and use_bbox cannot be True at the + same time, as they are mutually exclusive''') + local_rp_update = bool(use_curr_view or use_bbox or preloaded_bbox) posData = self.data[self.pos_i] # Update rp for current posData.lab (e.g. after any change) - if wl_update: if self.whitelistOriginalIDs is None: - old_IDs = posData.allData_li[posData.frame_i]['IDs'].copy() # for whitelist stuff + old_IDs = posData.allData_li[posData.frame_i]['regionprops'].IDs.copy() # for whitelist stuff else: old_IDs = self.whitelistOriginalIDs.copy() self.whitelistOriginalIDs = None elif self.whitelistOriginalIDs is None: - self.whitelist_old_IDs = posData.allData_li[posData.frame_i]['IDs'].copy() - - posData.rp = skimage.measure.regionprops(posData.lab) - if update_IDs: - IDs = [] - IDs_idxs = {} - for idx, obj in enumerate(posData.rp): - IDs.append(obj.label) - IDs_idxs[obj.label] = idx - posData.IDs = IDs - posData.IDs_idxs = IDs_idxs + self.whitelist_old_IDs = ( + posData.allData_li[posData.frame_i]['regionprops'].IDs.copy()) + + # check if only one of assignments, deletionIDs or only_current_view is given + if sum([assignments is not None, + deletionIDs is not None, + local_rp_update, + ]) > 1: + print(assignments is not None, deletionIDs is not None, local_rp_update) + raise ValueError('Only one of assignments, deletionIDs, ' + 'use_curr_view or use_bbox, preloaded_bbox can be used ' + 'at a time') + + if not isinstance(specific_IDs, (list, set, np.ndarray)) and specific_IDs is not None: + specific_IDs = [specific_IDs] + elif specific_IDs is not None and len(specific_IDs) == 0: + specific_IDs = None + + + # posData.rp is an acdcRegionprops instance here. + # if rp is None (can sometimes happen appearantly???) + if posData.rp is None: + printl(f'''Warning: posData.rp is None for pos {self.pos_i}, + frame {posData.frame_i}. Recomputing rp from labels.''') + + posData.rp = regionprops.acdcRegionprops( + posData.lab, precache_centroids=False + ) + + if assignments is not None: + # {old_ID: new_ID, ...} + posData.rp.update_regionprops_via_assignments(assignments, posData.lab) + elif deletionIDs is not None: + # (delID1, delID2, ...) + posData.rp.update_regionprops_via_deletions(deletionIDs) + elif local_rp_update: + # first get current view + if preloaded_bbox is None: + preloaded_bbox = self.update_rp_get_bbox(use_bbox=use_bbox, use_curr_view=use_curr_view, + specific_IDs=specific_IDs) + if preloaded_bbox is not False: + posData.rp.update_regionprops_via_cutout( + posData.lab, cutout_bbox=preloaded_bbox, specific_IDs=specific_IDs + ) + # if ID touches border but is not in specific_IDs, it will not be updated, + # so be careful! + else: + posData.rp.update_regionprops( + posData.lab + ) + else: + posData.rp.update_regionprops( + posData.lab, + specific_IDs_update_centroids=specific_IDs if preloaded_bbox is not False else None, # since sometimes I preload + ) + posData.IDs = posData.rp.IDs + self.update_rp_metadata(draw=draw) - self.store_zslices_rp(force_update=True) if not wl_update: return @@ -23052,12 +23618,12 @@ def updateTempLayerKeepIDs(self): def highlightLabelID(self, ID, ax=0): posData = self.data[self.pos_i] - try: - obj = posData.rp[posData.IDs_idxs[ID]] - except KeyError: + obj = posData.rp.get_obj_from_ID(ID, warn=False) + if obj is None: return - self.textAnnot[ax].highlightObject(obj) + self.textAnnot[ax].highlightObject( + obj, rp=posData.rp, getObjCentroidFunc=self.getObjCentroid) def _keepObjects(self, keepIDs=None, lab=None, rp=None): posData = self.data[self.pos_i] @@ -23087,7 +23653,7 @@ def removeHighlightLabelID(self, IDs=None, ax=0): IDs = posData.IDs for ID in IDs: - obj = posData.rp[posData.IDs_idxs[ID]] + obj = posData.rp.get_obj_from_ID(ID) self.textAnnot[ax].removeHighlightObject(obj) def updateKeepIDs(self, IDs): @@ -23241,7 +23807,9 @@ def applyKeepObjects(self): elif includeUnvisited: # Unvisited frame (includeUnvisited = True) lab = posData.segm_data[i] - rp = skimage.measure.regionprops(lab) + rp = regionprops.acdcRegionprops( + lab, precache_centroids=False + ) keepLab = self._keepObjects(lab=lab, rp=rp) posData.segm_data[i] = keepLab @@ -23323,7 +23891,8 @@ def annotate_rip_and_bin_IDs(self, updateLabel=False): continue if obj.excluded: - y, x = self.getObjCentroid(obj.centroid) + ID = obj.label + y, x = self.getObjCentroid(posData.rp.get_centroid(ID)) binnedIDs_xx.append(x) binnedIDs_yy.append(y) if updateLabel: @@ -23331,7 +23900,8 @@ def annotate_rip_and_bin_IDs(self, updateLabel=False): how = self.drawIDsContComboBox.currentText() if obj.dead: - y, x = self.getObjCentroid(obj.centroid) + ID = obj.label + y, x = self.getObjCentroid(posData.rp.get_centroid(ID)) ripIDs_xx.append(x) ripIDs_yy.append(y) if updateLabel: @@ -24011,10 +24581,8 @@ def zoomToObj(self, obj=None): posData = self.data[self.pos_i] if obj is None: ID = self.sender().value() - try: - ID_idx = posData.IDs_idxs[ID] - obj = obj = posData.rp[ID_idx] - except Exception as e: + obj = posData.rp.get_obj_from_ID(ID, warn=False) + if obj is None: self.logger.warning( f'ID {ID} does not exist (add points by clicking)' ) @@ -24066,7 +24634,7 @@ def pointsLayerAutoPilot(self, direction): return try: - ID_idx = posData.IDs_idxs[ID] + ID_idx = posData.rp.ID_to_idx[ID] if direction == 'next': nextID_idx = ID_idx + 1 else: @@ -24194,7 +24762,7 @@ def checkLoadedTableIds(self, toolbar): for posData in self.data: for tableEndName, df in posData.clickEntryPointsDfs.items(): for point_id in df['id'].values: - if point_id in posData.IDs_idxs: + if point_id in posData.rp.IDs: proceed = self.warnAddingPointWithExistingId( point_id, table_endname=tableEndName ) @@ -24428,10 +24996,10 @@ def setHoverCircleAddPoint(self, x, y): def isPointIdAlreadyNew(self, point_id, action): posData = self.data[self.pos_i] - if point_id in posData.IDs_idxs: + if point_id in posData.rp.IDs: return False - is_ID = point_id in posData.IDs_idxs + is_ID = point_id in posData.rp.IDs pointsDataPos = action.pointsData.get(self.pos_i) if pointsDataPos is None: return not is_ID @@ -24577,6 +25145,8 @@ def getCentroidsPointsData(self, action): # Centroids (either weighted or not) # NOTE: if user requested to draw from table we load that in # apps.AddPointsLayerDialog.ok_cb() + + # this does not have the updated centroid logic to avoid weird behaviours posData = self.data[self.pos_i] action.pointsData[self.pos_i] = {posData.frame_i: {}} if hasattr(action, 'weighingData'): @@ -25170,30 +25740,42 @@ def initTextAnnot(self, force=False): Y, X = posData.img_data.shape[-2:] self.textAnnot[0].initItem((Y, X)) self.textAnnot[1].initItem((Y, X)) - + + def _get_obj_for_current_view_rp(self, obj, posData): + # 2D segmentation already has the correct regionprops object. + if not self.isSegm3D or not posData.rp.is3D: + return obj + + slicing = self.switchPlaneCombobox.depthAxes() + zProjHow = self.zProjComboBox.currentText() + if zProjHow == 'single z-slice': + slice_selector = self.z_lab() + slice_number = slice_selector[-1] if isinstance(slice_selector, tuple) else slice_selector + obj_current_view = posData.rp.get_obj_from_slice_rp( + obj.label, slice_number, slicing=slicing, warn=False + ) + return obj_current_view or obj + + obj_current_view = posData.rp.get_obj_from_proj_rp( + obj.label, kind='most_common', slicing=slicing, warn=False + ) + return obj_current_view or obj + def getObjContours( - self, obj, all_external=False, local=False, force_calc=True, - include_internal=False + self, obj, all_external=False, local=False, + include_internal=False, rp=None ): posData = self.data[self.pos_i] - dataDict = posData.allData_li[posData.frame_i] - allContours = dataDict.get('contours') - if allContours is not None and not force_calc: - z = self.z_lab() - key = (obj.label, str(z), all_external, local) - contours = allContours.get(key) - if contours is not None: - return contours + obj_to_use = self._get_obj_for_current_view_rp(obj, posData) - obj_image = self.getObjImage(obj.image, obj.bbox).astype(np.uint8) - obj_bbox = self.getObjBbox(obj.bbox) try: contours = core.get_obj_contours( - obj_image=obj_image, - obj_bbox=obj_bbox, + obj=obj_to_use, local=local, - all_external=all_external + all_external=all_external, + all=include_internal ) + except Exception as e: if all_external: contours = [] @@ -25202,168 +25784,9 @@ def getObjContours( self.logger.warning( f'Object ID {obj.label} contours drawing failed. ' f'(bounding box = {obj.bbox})' + f'Error: {e}' ) return contours - - def clearComputedContours(self): - for posData in self.data: - for frame_i, dataDict in enumerate(posData.allData_li): - dataDict['contours'] = {} - - def _computeAllContours2D( - self, dataDict, obj, z, obj_bbox, include_internal=False - ): - obj_image = self.getObjImage(obj.image, obj.bbox, z_slice=z) - if obj_image is None: - return - - all_external = False - local = False - contours = core.get_obj_contours( - obj_image=obj_image, - obj_bbox=obj_bbox, - local=local, - all_external=all_external - ) - key = (obj.label, str(z), all_external, local) - dataDict['contours'][key] = contours - - all_external = True - local = False - contours = core.get_obj_contours( - obj_image=obj_image, - obj_bbox=obj_bbox, - local=local, - all_external=all_external, - all=include_internal - ) - key = (obj.label, str(z), all_external, local) - dataDict['contours'][key] = contours - - return dataDict - - def computeAllContours(self): - self.logger.info('Computing all contours...') - posData = self.data[self.pos_i] - zz = [None] - if self.isSegm3D: - zz.extend(range(posData.SizeZ)) - - include_internal = self.showAllContoursToggle.isChecked() - for frame_i, dataDict in enumerate(posData.allData_li): - lab = dataDict['labels'] - if lab is None: - break - - rp = dataDict['regionprops'] - if rp is None: - rp = skimage.measure.regionprops(lab) - - dataDict['contours'] = {} - for obj in rp: - obj_bbox = self.getObjBbox(obj.bbox) - for z in zz: - if not self.isObjVisible(obj.bbox, z_slice=z): - continue - - try: - self._computeAllContours2D( - dataDict, obj, z, obj_bbox, - include_internal=include_internal - ) - except Exception as err: - # Contours computation fails on weird objects - pass - - def computeAllObjToObjCostPairs(self): - desc = ( - 'Computing all object-to-object cost matrices...' - ) - self.logger.info(desc) - posData = self.data[self.pos_i] - - - self.progressWin = apps.QDialogWorkerProgress( - title=desc, parent=self, pbarDesc=desc - ) - self.progressWin.mainPbar.setMaximum(0) - self.progressWin.show(self.app) - - self.computeAllObjCostPairsThread = QThread() - self.computeAllObjCostPairsWorker = workers.SimpleWorker( - posData, self._computeAllObjToObjCostPairs - ) - - self.computeAllObjCostPairsWorker.moveToThread( - self.computeAllObjCostPairsThread - ) - - self.computeAllObjCostPairsWorker.signals.finished.connect( - self.computeAllObjCostPairsThread.quit - ) - self.computeAllObjCostPairsWorker.signals.finished.connect( - self.computeAllObjCostPairsWorker.deleteLater - ) - self.computeAllObjCostPairsThread.finished.connect( - self.computeAllObjCostPairsThread.deleteLater - ) - - self.computeAllObjCostPairsWorker.signals.critical.connect( - self.computeAllObjCostPairsWorkerCritical - ) - self.computeAllObjCostPairsWorker.signals.initProgressBar.connect( - self.workerInitProgressbar - ) - self.computeAllObjCostPairsWorker.signals.progressBar.connect( - self.workerUpdateProgressbar - ) - self.computeAllObjCostPairsWorker.signals.progress.connect( - self.workerProgress - ) - self.computeAllObjCostPairsWorker.signals.finished.connect( - self.computeAllObjCostPairsWorkerFinished - ) - - self.computeAllObjCostPairsThread.started.connect( - self.computeAllObjCostPairsWorker.run - ) - self.computeAllObjCostPairsThread.start() - - self.computeAllObjCostPairsWorkerLoop = QEventLoop() - self.computeAllObjCostPairsWorkerLoop.exec_() - - def _computeAllObjToObjCostPairs(self, posData): - self.computeAllObjCostPairsWorker.signals.initProgressBar.emit( - len(posData.allData_li) - ) - for frame_i, dataDict in enumerate(posData.allData_li): - if frame_i == 0: - continue - - rp = dataDict['regionprops'] - if rp is None: - break - - prev_rp = posData.allData_li[frame_i-1]['regionprops'] - dist_matrix = core._compute_all_obj_to_obj_contour_dist_pairs( - dataDict['contours'], rp, - prev_rp=prev_rp, - restrict_search=True - ) - dataDict['obj_to_obj_dist_cost_matrix_df'] = dist_matrix - self.computeAllObjCostPairsWorker.signals.progressBar.emit(1) - self.computeAllObjCostPairsWorker.signals.initProgressBar.emit(0) - - def computeAllObjCostPairsWorkerCritical(self, error): - self.computeAllObjCostPairsWorkerLoop.exit() - self.workerCritical(error) - - def computeAllObjCostPairsWorkerFinished(self, output): - if self.progressWin is not None: - self.progressWin.workerFinished = True - self.progressWin.close() - self.progressWin = None - self.computeAllObjCostPairsWorkerLoop.exit() def setOverlaySegmMasks(self, force=False, forceIfNotActive=False): if not hasattr(self, 'currentLab2D'): @@ -25396,27 +25819,21 @@ def setOverlaySegmMasks(self, force=False, forceIfNotActive=False): self.extendLabelsLUT(maxID+10) currentLab2D = self.currentLab2D + if ( + self.isSegm3D + and self.zProjComboBox.currentText() != 'single z-slice' + and getattr(posData, 'rp', None) is not None + and posData.rp.is3D + ): + slicing = self.switchPlaneCombobox.depthAxes() + currentLab2D = posData.rp.get_projection_lab_sorted(slicing=slicing) + self.currentLab2D = currentLab2D + if isOverlaySegmLeftActive: self.labelsLayerImg1.setImage(currentLab2D, autoLevels=False) if isOverlaySegmRightActive: self.labelsLayerRightImg.setImage(currentLab2D, autoLevels=False) - - def getObject2DimageFromZ(self, z, obj): - posData = self.data[self.pos_i] - z_min = obj.bbox[0] - local_z = z - z_min - if local_z >= posData.SizeZ or local_z < 0: - return - return obj.image[local_z] - - def getObject2DsliceFromZ(self, z, obj): - posData = self.data[self.pos_i] - z_min = obj.bbox[0] - local_z = z - z_min - if local_z >= posData.SizeZ or local_z < 0: - return - return obj.image[local_z] def isObjVisible(self, obj_bbox, debug=False, z_slice=None): if z_slice is None: @@ -25457,13 +25874,12 @@ def getObjImage(self, obj_image, obj_bbox, z_slice=None): # required a projection return obj_image.max(axis=0) - min_z = obj_bbox[0] if z_slice is None: z_slice = self.z_lab() if isinstance(z_slice, tuple): z_slice = z_slice[-1] - local_z = z_slice - min_z + local_z = z_slice - obj_bbox[0] try: obi_image_2d = obj_image[local_z] except Exception as err: @@ -26338,6 +26754,7 @@ def applyDelROIs(self): self.get_data() def initTempLayerBrush(self, ID, ax=0): + posData = self.data[self.pos_i] if ax == 0: how = self.drawIDsContComboBox.currentText() else: @@ -26347,7 +26764,24 @@ def initTempLayerBrush(self, ID, ax=0): Y, X = self.img1.image.shape[:2] tempImage = np.zeros((Y, X), dtype=np.uint32) if how.find('contours') != -1: - tempImage[self.currentLab2D==ID] = ID + # Keep the currently edited object visible while painting. + obj = None + rp = getattr(posData, 'rp', None) + if rp is not None: + obj = rp.get_obj_from_ID(ID, warn=False) + if obj is not None: + obj = self._get_obj_for_current_view_rp(obj, posData) + + if ( + obj is not None + and hasattr(obj, 'slice') + and hasattr(obj, 'image') + and len(obj.slice) == 2 + and obj.image.ndim == 2 + ): + tempImage[obj.slice][obj.image] = ID + else: + tempImage[self.currentLab2D == ID] = ID self.brushImage = tempImage.copy() self.brushContourImage = np.zeros((Y, X, 4), dtype=np.uint8) color = self.imgGrad.contoursColorButton.color() @@ -26383,6 +26817,7 @@ def setTempBrushMaskFromWand(self, flood_mask, init=False): # @exec_time def setTempImg1Brush(self, init: bool, mask, ID, toLocalSlice=None, ax=0): + posData = self.data[self.pos_i] if init: self.initTempLayerBrush(ID, ax=ax) @@ -26397,14 +26832,11 @@ def setTempImg1Brush(self, init: bool, mask, ID, toLocalSlice=None, ax=0): brushImage[toLocalSlice][mask] = ID if self.annotContourCheckbox.isChecked(): - try: - obj = skimage.measure.regionprops(brushImage)[0] - except IndexError: - return - objContour = [self.getObjContours(obj)] - # objContour = core.get_obj_contours( - # obj_image=(brushImage>0).astype(np.uint8), local=True - # ) + brushMask = np.ascontiguousarray((brushImage > 0), dtype=np.uint8) + + objContour = core.get_obj_contours( + obj_image=brushMask, obj_bbox=None, all_external=True, local=True + ) self.brushContourImage[:] = 0 img = self.brushContourImage color = self.brushContoursRgba @@ -26447,9 +26879,22 @@ def setTempImg1Eraser(self, mask, init=False, toLocalSlice=None, ax=0): self.clearObjFromMask( self.contoursImage, mask, toLocalSlice=toLocalSlice ) - erasedRp = skimage.measure.regionprops(self.erasedLab) - for obj in erasedRp: - self.addObjContourToContoursImage(obj=obj, ax=ax) + thickness = self.contLineWeight + color = self.contLineColor + for erasedID in np.unique(self.erasedLab): + if erasedID == 0: + continue + erasedMask = np.ascontiguousarray( + (self.erasedLab == erasedID), dtype=np.uint8 + ) + contours = core.get_obj_contours( + obj_image=erasedMask, obj_bbox=None, + all_external=True, local=True + ) + cv2.drawContours(self.contoursImage, contours, -1, color, thickness) + imageItem = self.getContoursImageItem(ax) + if imageItem is not None: + imageItem.setImage(self.contoursImage) elif how.find('overlay segm. masks') != -1: labelsImage = self.getLabelsLayerImage(ax=ax) self.clearObjFromMask(labelsImage, mask, toLocalSlice=toLocalSlice) @@ -26462,13 +26907,13 @@ def setTempImg1Eraser(self, mask, init=False, toLocalSlice=None, ax=0): self.labelsLayerRightImg.image, autoLevels=False ) - def _setTempImgExpandLabelSegmMasks(self, prevCoords, ax=0): + def _setTempImgExpandLabelSegmMasks(self, prevCoords, expandedObjCoords, ax=0): # Remove previous overlaid mask labelsImage = self.getLabelsLayerImage(ax=ax) labelsImage[prevCoords] = 0 - # Overlay new moved mask - labelsImage[prevCoords] = self.expandingID + # Overlay new expanded mask + labelsImage[expandedObjCoords] = self.expandingID if ax == 0: self.labelsLayerImg1.setImage( @@ -26479,25 +26924,31 @@ def _setTempImgExpandLabelSegmMasks(self, prevCoords, ax=0): def _setTempImgExpandLabelContours(self, prevCoords, ax=0): self.contoursImage[prevCoords] = [0,0,0,0] - currentLab2Drp = skimage.measure.regionprops(self.currentLab2D) - for obj in currentLab2Drp: - if obj.label == self.expandingID: - # self.clearObjContour(obj=obj, ax=ax) - self.addObjContourToContoursImage(obj=obj, ax=ax, force=True) - break + expandMask = np.ascontiguousarray( + (self.currentLab2D == self.expandingID), dtype=np.uint8 + ) + contours = core.get_obj_contours( + obj_image=expandMask, obj_bbox=None, all_external=True, local=True + ) + imageItem = self.getContoursImageItem(ax, force=True) + if imageItem is not None: + thickness = self.contLineWeight + color = self.contLineColor + cv2.drawContours(self.contoursImage, contours, -1, color, thickness) + imageItem.setImage(self.contoursImage) def setTempImgExpandLabel(self, prevCoords, expandedObjCoords, ax=0): if ax == 0: how = self.drawIDsContComboBox.currentText() else: how = self.getAnnotateHowRightImage() - - self._setTempImgExpandLabelContours(prevCoords, ax=ax) - - # if how.find('overlay segm. masks') != -1: - # self._setTempImgExpandLabelSegmMasks(ax=ax) - # else: - # self._setTempImgExpandLabelContours(ax=ax) + + if how.find('overlay segm. masks') != -1: + self._setTempImgExpandLabelSegmMasks( + prevCoords, expandedObjCoords, ax=ax + ) + else: + self._setTempImgExpandLabelContours(prevCoords, ax=ax) def setTempImg1MoveLabel(self, ax=0): if ax == 0: @@ -26506,11 +26957,18 @@ def setTempImg1MoveLabel(self, ax=0): how = self.getAnnotateHowRightImage() if how.find('contours') != -1: - currentLab2Drp = skimage.measure.regionprops(self.currentLab2D) - for obj in currentLab2Drp: - if obj.label == self.movingID: - self.addObjContourToContoursImage(obj=obj, ax=ax) - break + moveMask = np.ascontiguousarray( + (self.currentLab2D == self.movingID), dtype=np.uint8 + ) + contours = core.get_obj_contours( + obj_image=moveMask, obj_bbox=None, all_external=True, local=True + ) + imageItem = self.getContoursImageItem(ax) + if imageItem is not None: + thickness = self.contLineWeight + color = self.contLineColor + cv2.drawContours(self.contoursImage, contours, -1, color, thickness) + imageItem.setImage(self.contoursImage) elif how.find('overlay segm. masks') != -1: if ax == 0: self.labelsLayerImg1.setImage(self.currentLab2D, autoLevels=False) @@ -26611,7 +27069,8 @@ def updateCcaDfDeletedIDsTimelapse( else: for delID in deletedIDs: dataDict = posData.allData_li[fut_frame_i] - delIDexists = dataDict['IDs_idxs'].get(delID, False) + rp = dataDict['regionprops'] + delIDexists = delID in rp.IDs if not delIDexists: continue @@ -26648,7 +27107,8 @@ def updateCcaDfDeletedIDsTimelapse( else: for delID in deletedIDs: dataDict = posData.allData_li[past_frame_i] - delIDexists = dataDict['IDs_idxs'].get(delID, False) + rp = dataDict['regionprops'] + delIDexists = delID in rp.IDs if not delIDexists: continue @@ -26944,8 +27404,7 @@ def highlightHoverID(self, x, y, hoverID=None): return posData = self.data[self.pos_i] - objIdx = posData.IDs_idxs[hoverID] - obj = posData.rp[objIdx] + obj = posData.rp.get_obj_from_ID(hoverID) self.goToZsliceSearchedID(obj) self.highlightSearchedID(hoverID) @@ -27009,12 +27468,10 @@ def highlightHoverIDsKeptObj(self, x, y, hoverID=None): return posData = self.data[self.pos_i] - try: - objIdx = posData.IDs_idxs[hoverID] - except KeyError as err: - return + obj = posData.rp.get_obj_from_ID(hoverID, warn=False) + if obj is None: + return - obj = posData.rp[objIdx] self.goToZsliceSearchedID(obj) for ID in self.keptObjectsIDs: @@ -27079,11 +27536,10 @@ def highlightSearchedID(self, ID, force=False, greyOthers=True): self.highlightedID = ID self.highlightIDToolbar.setVisible(True) - objIdx = posData.IDs_idxs.get(ID) - if objIdx is None: + obj = posData.rp.get_obj_from_ID(ID, warn=False) + if obj is None: return - obj = posData.rp[objIdx] isObjVisible = self.isObjVisible(obj.bbox) if not isObjVisible: return @@ -27111,7 +27567,11 @@ def highlightSearchedID(self, ID, force=False, greyOthers=True): self.highLightIDLayerImg1.setImage(self.highlightedLab) self.labelsLayerImg1.setOpacity(alpha/3) else: - contours = self.getObjContours(obj, all_external=True) + contours = self.getObjContours( + obj, + all_external=True, + include_internal=self.showAllContoursToggle.isChecked() + ) for cont in contours: self.searchedIDitemLeft.addPoints(cont[:,0]+0.5, cont[:,1]+0.5) @@ -27121,7 +27581,11 @@ def highlightSearchedID(self, ID, force=False, greyOthers=True): self.labelsLayerRightImg.setOpacity(alpha/3) else: if contours is None: - contours = self.getObjContours(obj, all_external=True) + contours = self.getObjContours( + obj, + all_external=True, + include_internal=self.showAllContoursToggle.isChecked() + ) for cont in contours: self.searchedIDitemRight.addPoints(cont[:,0]+0.5, cont[:,1]+0.5) @@ -27314,8 +27778,14 @@ def setManualBackgroundImage(self): self.initManualBackgroundImage() contours = [] - for obj in skimage.measure.regionprops(posData.manualBackgroundLab): - obj_contours = self.getObjContours(obj, all_external=True) + for obj in regionprops.acdcRegionprops( + posData.manualBackgroundLab, precache_centroids=False + ): + obj_contours = self.getObjContours( + obj, + all_external=True, + include_internal=self.showAllContoursToggle.isChecked() + ) contours.extend(obj_contours) textItem = self.manualBackgroundTextItems[obj.label] textItem.setText(f'{obj.label}') @@ -27331,7 +27801,7 @@ def setManualBackgroundImage(self): def setManualBackgrounNextID(self): posData = self.data[self.pos_i] currentID = self.manualBackgroundObj.label - idx = posData.IDs_idxs[currentID] + idx = posData.rp.ID_to_idx[currentID] next_idx = idx + 1 if next_idx >= len(posData.IDs): return @@ -27390,7 +27860,9 @@ def setManualBackgroundLab(self, load_from_store=False, debug=True): if posData.manualBackgroundLab is None: self.initManualBackgroundImage() - for obj in skimage.measure.regionprops(posData.manualBackgroundLab): + for obj in regionprops.acdcRegionprops( + posData.manualBackgroundLab, precache_centroids=False + ): textItem = pg.TextItem(text='', color='r', anchor=(0.5, 0.5)) if obj.label in self.manualBackgroundTextItems: continue @@ -27407,13 +27879,26 @@ def updateContoursImage(self, ax, delROIsIDs=None, compute=True): self.contoursImage[:] = 0 contours = [] - for obj in skimage.measure.regionprops(self.currentLab2D): + posData = self.data[self.pos_i] + rp = posData.rp + use_local_rp = ( + self.isSegm3D + and self.zProjComboBox.currentText() == 'single z-slice' + ) + if rp is None: + lab = self.currentLab2D + rp = regionprops.acdcRegionprops(lab, precache_centroids=False) + if not use_local_rp: + posData.rp = rp + elif use_local_rp and rp.is3D: + rp = self.rpCurr2D() + + for obj in rp: obj_contours = self.getObjContours( - obj, - all_external=True, - force_calc=compute, + obj, + all_external=True, include_internal=self.showAllContoursToggle.isChecked() - ) + ) contours.extend(obj_contours) thickness = self.contLineWeight @@ -27426,17 +27911,18 @@ def setContoursImage(self, imageItem, contours, thickness, color): def getObjFromID(self, ID): posData = self.data[self.pos_i] - try: - idx = posData.IDs_idxs[ID] - except KeyError as e: + obj = posData.rp.get_obj_from_ID(ID, warn=False) + if obj is None: # Object already cleared return - - obj = posData.rp[idx] return obj def setLostObjectContour(self, obj): - allContours = self.getObjContours(obj, all_external=True) + allContours = self.getObjContours( + obj, + all_external=True, + include_internal=self.showAllContoursToggle.isChecked() + ) for objContours in allContours: xx = objContours[:,0] + 0.5 yy = objContours[:,1] + 0.5 @@ -27448,7 +27934,11 @@ def setTrackedLostObjectContour(self, obj): if self.isExportingVideo: return - allContours = self.getObjContours(obj, all_external=True) + allContours = self.getObjContours( + obj, + all_external=True, + include_internal=self.showAllContoursToggle.isChecked() + ) for objContours in allContours: xx = objContours[:,0] + 0.5 yy = objContours[:,1] + 0.5 @@ -27472,7 +27962,6 @@ def updateLostContoursImage(self, ax, draw=True, delROIsIDs=None): posData = self.data[self.pos_i] prev_rp = posData.allData_li[posData.frame_i-1]['regionprops'] - prev_IDs_idxs = posData.allData_li[posData.frame_i-1]['IDs_idxs'] if posData.whitelist is not None and posData.whitelist.whitelistIDs is not None: whitelist = posData.whitelist.whitelistIDs[posData.frame_i-1] else: @@ -27483,11 +27972,15 @@ def updateLostContoursImage(self, ax, draw=True, delROIsIDs=None): if lostID in delROIsIDs or (whitelist is not None and lostID not in whitelist): continue - obj = prev_rp[prev_IDs_idxs[lostID]] + obj = prev_rp.get_obj_from_ID(lostID) if not self.isObjVisible(obj.bbox): continue - obj_contours = self.getObjContours(obj, all_external=True) + obj_contours = self.getObjContours( + obj, + all_external=True, + include_internal=self.showAllContoursToggle.isChecked() + ) if ax == 0: self.addLostObjsToLostObjImage(obj, lostID) @@ -27528,17 +28021,20 @@ def updateLostTrackedContoursImage( tracked_lost_IDs = self.getTrackedLostIDs() prev_rp = posData.allData_li[posData.frame_i-1]['regionprops'] - prev_IDs_idxs = posData.allData_li[posData.frame_i-1]['IDs_idxs'] contours = [] for tracked_lost_ID in tracked_lost_IDs: if tracked_lost_ID in delROIsIDs: continue - obj = prev_rp[prev_IDs_idxs[tracked_lost_ID]] + obj = prev_rp.get_obj_from_ID(tracked_lost_ID) if not self.isObjVisible(obj.bbox): continue - obj_contours = self.getObjContours(obj, all_external=True) + obj_contours = self.getObjContours( + obj, + all_external=True, + include_internal=self.showAllContoursToggle.isChecked() + ) contours.extend(obj_contours) self.drawLostTrackedObjContoursImage(imageItem, contours) @@ -27596,13 +28092,20 @@ def getNearestLostObjID(self, y, x): return nearest_ID def setCcaIssueContour(self, obj): - objContours = self.getObjContours(obj, all_external=True) + objContours = self.getObjContours( + obj, + all_external=True, + include_internal=self.showAllContoursToggle.isChecked() + ) for cont in objContours: xx = cont[:,0] + 0.5 yy = cont[:,1] + 0.5 self.ax1_lostObjScatterItem.addPoints(xx, yy) + + posData = self.data[self.pos_i] self.textAnnot[0].addObjAnnotation( - obj, 'lost_object', f'{obj.label}?', False + obj, 'lost_object', f'{obj.label}?', False, + rp=posData.rp, getObjCentroidFunc=self.getObjCentroid ) def isLastVisitedAgainCca(self, curr_df, enforceAll=False): @@ -27650,7 +28153,8 @@ def highlightNewCellNotEnoughG1cells(self, IDsCellsG1): yy = objContours[:,1] + 0.5 self.ccaFailedScatterItem.addPoints(xx, yy) self.textAnnot[0].addObjAnnotation( - obj, 'green', f'{obj.label}?', False + obj, 'green', f'{obj.label}?', False, + rp=posData.rp, getObjCentroidFunc=self.getObjCentroid ) def handleNoCellsInG1(self, numCellsG1, numNewCells): @@ -27688,7 +28192,11 @@ def addObjContourToContoursImage( if obj is None: return - contours = self.getObjContours(obj, all_external=True) + contours = self.getObjContours( + obj, + all_external=True, + include_internal=self.showAllContoursToggle.isChecked() + ) if thickness is None: thickness = self.contLineWeight if color is None: @@ -27747,9 +28255,7 @@ def setAllTextAnnotations(self, labelsToSkip=None): return delROIsIDs def setAllContoursImages(self, delROIsIDs=None, compute=True): - if compute: - self.computeAllContours() - self.updateContoursImage(ax=0, delROIsIDs=delROIsIDs, compute=compute) + self.updateContoursImage(ax=0, delROIsIDs=delROIsIDs, compute=compute) #almost all from here self.updateContoursImage(ax=1, delROIsIDs=delROIsIDs, compute=compute) def setAllLostObjContoursImage(self, delROIsIDs=None): @@ -27875,7 +28381,6 @@ def keyDownCallback( QAbstractSlider.SliderAction.SliderSingleStepSub ) - # @exec_time @exception_handler def updateAllImages( self, image=None, computePointsLayers=True, computeContours=True, @@ -27957,22 +28462,19 @@ def deleteIDFromLab( lab = self.get_2Dlab(lab) if delMask is not None: delMask = self.get_2Dlab(delMask) - rp = skimage.measure.regionprops(lab) - IDs_idxs = {obj.label: idx for idx, obj in enumerate(rp)} - else: + rp = regionprops.acdcRegionprops(lab, precache_centroids=False) + else: if frame_i==posData.frame_i: rp = posData.rp - IDs_idxs = posData.IDs_idxs else: rp = posData.allData_li[frame_i]['regionprops'] - IDs_idxs = posData.allData_li[frame_i]['IDs_idxs'] if isinstance(delID, int): delID = [delID] is_any_id_present = False for _delID in delID: - if _delID in IDs_idxs: + if _delID in rp.IDs: is_any_id_present = True break @@ -27985,10 +28487,9 @@ def deleteIDFromLab( delMask[:] = False for _delID in delID: - idx = IDs_idxs.get(_delID, None) - if idx is None: + if _delID not in rp.IDs: continue - obj = rp[idx] + obj = rp.get_obj_from_ID(_delID) delMask[obj.slice][obj.image] = True lab[delMask] = 0 @@ -28001,31 +28502,6 @@ def deleteIDFromLab( return lab, delMask - def removeStoredContours(self, delID, frame_i=None, z_slice=None): - posData = self.data[self.pos_i] - - if frame_i is None: - frame_i = posData.frame_i - - dataDict = posData.allData_li[posData.frame_i] - try: - newContours = {} - for key, contours in dataDict['contours'].items(): - ID = key[0] - if ID == delID: - continue - - if z_slice is not None: - z_slice_i = key[1] - if z_slice_i != z_slice: - continue - - newContours[key] = contours - - dataDict['contours'] = newContours - except KeyError as err: - pass - @disableWindow def deleteIDmiddleClick( self, delIDs: Iterable, applyFutFrames, includeUnvisited, @@ -28084,7 +28560,6 @@ def deleteIDmiddleClick( self.clearObjContour(ID=_delID, ax=1) if z_slice is None: self.removeObjectFromRp(_delID) - self.removeStoredContours(_delID, z_slice=z_slice) if shift and self.isSegm3D: self.update_rp() @@ -28129,9 +28604,13 @@ def setOverlayLabelsItems(self, specific=None): imageItem, contoursItem, gradItem = items contoursItem.clear() if drawMode == 'Draw contours': - for obj in skimage.measure.regionprops(ol_lab): + for obj in regionprops.acdcRegionprops( + ol_lab, precache_centroids=False + ): contours = self.getObjContours( - obj, all_external=True + obj, + all_external=True, + include_internal=self.showAllContoursToggle.isChecked() ) for cont in contours: contoursItem.addPoints(cont[:,0]+0.5, cont[:,1]+0.5) @@ -28294,9 +28773,12 @@ def highlightHoverLostObj(self, modifiers, event): self.ax1_lostObjScatterItem.setData([], []) else: prev_rp = posData.allData_li[posData.frame_i-1]['regionprops'] - prev_IDs_idxs = posData.allData_li[posData.frame_i-1]['IDs_idxs'] - lostObj = prev_rp[prev_IDs_idxs[hoverLostID]] - obj_contours = self.getObjContours(lostObj, all_external=True) + lostObj = prev_rp.get_obj_from_ID(hoverLostID) + obj_contours = self.getObjContours( + lostObj, + all_external=True, + include_internal=self.showAllContoursToggle.isChecked() + ) for cont in obj_contours: xx = cont[:,0] yy = cont[:,1] @@ -28316,10 +28798,12 @@ def getPrevFrameIDs(self, current_frame_i=None): if current_frame_i is None: return [] - prev_frame_i = current_frame_i - 1 - prevIDs = posData.allData_li[prev_frame_i]['IDs'] + if current_frame_i == 0: + return [] - if prevIDs: + prev_frame_i = current_frame_i - 1 + if posData.allData_li[prev_frame_i]['regionprops'] is not None: + prevIDs = posData.allData_li[prev_frame_i]['regionprops'].IDs return prevIDs # IDs in previous frame were not stored --> load prev lab from HDD @@ -28328,8 +28812,11 @@ def getPrevFrameIDs(self, current_frame_i=None): frame_i=prev_frame_i, return_copy=False ) - rp = skimage.measure.regionprops(prev_lab) - prevIDs = [obj.label for obj in rp] + rp = regionprops.acdcRegionprops( + prev_lab, precache_centroids=False + ) + posData.allData_li[prev_frame_i]['regionprops'] = rp + prevIDs = rp.IDs return prevIDs # @exec_time @@ -28490,7 +28977,9 @@ def separateByLabelling(self, lab, rp, maxID=None): maxID = max(posData.IDs, default=1) for obj in rp: lab_obj = skimage.measure.label(obj.image) - rp_lab_obj = skimage.measure.regionprops(lab_obj) + rp_lab_obj = regionprops.acdcRegionprops( + lab_obj, precache_centroids=False + ) if len(rp_lab_obj)<=1: continue lab_obj += maxID @@ -28544,105 +29033,86 @@ def trackManuallyAddedObject( added_IDs = [added_IDs] posData = self.data[self.pos_i] - tracked_lab = self.tracking( + tracked_lab, assignments = self.tracking( enforce=True, assign_unique_new_IDs=False, return_lab=True, - IDs=added_IDs + specific_IDs=added_IDs, return_assignments=True, + against_next=posData.frame_i==0 ) + + # RP not updated after tracking!!! self.clearAssignedObjsSecondStep() if tracked_lab is None: return # Track only new object - prevIDs = posData.allData_li[posData.frame_i-1]['IDs'] - - # mask = np.zeros(posData.lab.shape, dtype=bool) - update_rp = False - + prevIDs = posData.allData_li[posData.frame_i-1]['regionprops'].IDs + + # assignments_new = dict() + # self.update_rp(assignments=assignments) for added_ID in added_IDs: - # try: - # obj = posData.rp[added_ID] # ID not present - # mask[obj.slice][obj.image] = True - - # except IndexError as err: - mask = posData.lab == added_ID + + # check if added ID is already present + # here PR is "stale" so ID maps are not tracked + obj = posData.rp.get_obj_from_ID(added_ID, warn=False) + if obj is None: + continue try: - trackedID = tracked_lab[mask][0] + trackedID = tracked_lab[obj.slice][obj.image][0] except IndexError as err: # added_ID is not present continue isTrackedIDalreadyPresentAndNotNew = ( - posData.IDs_idxs.get(trackedID) is not None + posData.rp.ID_to_idx.get(trackedID) is not None and added_ID != trackedID ) if isTrackedIDalreadyPresentAndNotNew: + self.updatePointsLayerClickEntryTableEndname( + 'added obj already present', added_ID, trackedID + ) continue isTrackedIDinPrevIDs = trackedID in prevIDs if isTrackedIDinPrevIDs: - posData.lab[mask] = trackedID + posData.lab[obj.slice][obj.image] = trackedID else: # New object where we can try to track against next frame - trackedID = self.trackNewIDtoNewIDsFutureFrame(added_ID, mask) + trackedID, assignments = self.trackNewIDtoNewIDsFutureFrame(added_ID, obj, assignments) if trackedID is None: self.clearAssignedObjsSecondStep() continue - posData.lab[mask] = trackedID + posData.lab[obj.slice][obj.image] = trackedID self.keepOnlyNewIDAssignedObjsSecondStep(trackedID) - update_rp = True - if update_rp: - self.update_rp(wl_update=wl_update) - + self.update_rp(wl_update=wl_update, assignments=assignments) + def trackFrameCustomTracker( - self, prev_lab, currentLab, IDs=None, unique_ID=None + self, prev_lab, currentLab, specific_IDs=None, unique_ID=None, + return_assignments=True, dont_return_tracked_lab=False ): if unique_ID is None: unique_ID = self.setBrushID() - try: - tracked_result = self.realTimeTracker.track_frame( - prev_lab, currentLab, - unique_ID=unique_ID, - IDs=IDs, - **self.track_frame_params, - ) - except TypeError as err: - if str(err).find('an unexpected keyword argument \'unique_ID\'') != -1: - try: - tracked_result = self.realTimeTracker.track_frame( - prev_lab, currentLab, IDs=IDs, - **self.track_frame_params - ) - except TypeError as err: - if str(err).find('an unexpected keyword argument \'IDs\'') != -1: - tracked_result = self.realTimeTracker.track_frame( - prev_lab, currentLab, - **self.track_frame_params) - else: - raise err - elif str(err).find('an unexpected keyword argument \'IDs\'') != -1: - try: - tracked_result = self.realTimeTracker.track_frame( - prev_lab, currentLab, - unique_ID=unique_ID, - **self.track_frame_params - ) - except TypeError as err: - if str(err).find('an unexpected keyword argument \'unique_ID\'') != -1: - tracked_result = self.realTimeTracker.track_frame( - prev_lab, currentLab, - **self.track_frame_params - ) - else: - raise err - else: - raise err + + kwargs_total = { + 'unique_ID': unique_ID, + 'return_assignments': return_assignments, + 'dont_return_tracked_lab': dont_return_tracked_lab, + 'specific_IDs': specific_IDs + } + kwargs_total.update(self.track_frame_params) + + kwargs = {k: v for k, v in kwargs_total.items() if k in self.realTimeTracker_kwargs} + tracked_result = self.realTimeTracker.track_frame( + prev_lab, currentLab, + **kwargs, + ) return tracked_result def trackFrame( self, prev_lab, prev_rp, curr_lab, curr_rp, curr_IDs, - assign_unique_new_IDs=True, IDs=None, unique_ID=None + assign_unique_new_IDs=True, specific_IDs=None, unique_ID=None, + dont_return_tracked_lab=False, return_assignments=False, ): if self.trackWithAcdcAction.isChecked(): tracked_result = CellACDC_tracker.track_frame( @@ -28651,8 +29121,10 @@ def trackFrame( setBrushID_func=self.setBrushID, posData=self.data[self.pos_i], assign_unique_new_IDs=assign_unique_new_IDs, - IDs=IDs, - unique_ID=unique_ID + specific_IDs=specific_IDs, + unique_ID=unique_ID, + return_assignments=return_assignments, + dont_return_tracked_lab=dont_return_tracked_lab ) elif self.trackWithYeazAction.isChecked(): tracked_result = self.tracking_yeaz.correspondence( @@ -28661,17 +29133,43 @@ def trackFrame( ) else: tracked_result = self.trackFrameCustomTracker( - prev_lab, curr_lab, IDs=IDs, unique_ID=unique_ID + prev_lab, curr_lab, specific_IDs=specific_IDs, unique_ID=unique_ID, + dont_return_tracked_lab=dont_return_tracked_lab, return_assignments=return_assignments ) # Check if tracker also returns additional info + assignments = None if isinstance(tracked_result, tuple): - tracked_lab, tracked_lost_IDs = tracked_result - self.handleAdditionalInfoRealTimeTracker(prev_rp, tracked_lost_IDs) + tracked_lab, add_info = tracked_result + assignments = self.handleAdditionalInfoRealTimeTracker( + prev_rp, add_info) + elif isinstance(tracked_result, dict) and dont_return_tracked_lab: + add_info = tracked_result + if 'assignments' in add_info: # if still entire add_info is returned + assignments = self.handleAdditionalInfoRealTimeTracker( + prev_rp, add_info) + else: + assignments = add_info # its just assignements else: tracked_lab = tracked_result - return tracked_lab + if not return_assignments and not dont_return_tracked_lab: + return tracked_lab + + # get assignments + if assignments is None: + assignments = dict() + for obj in curr_rp: + try: + old_lab = obj.label + new_lab = tracked_lab[obj.slice][obj.image][0] + assignments[old_lab] = new_lab + except: + import pdb; pdb.set_trace() + + if dont_return_tracked_lab: + return assignments + return tracked_lab, assignments def clearAssignedObjsSecondStep(self): posData = self.data[self.pos_i] @@ -28681,37 +29179,32 @@ def trackSubsetIDs(self, subsetIDs: Iterable[int]): posData = self.data[self.pos_i] if posData.frame_i == 0: return - - subsetLab = np.zeros_like(posData.lab) - for subsetID in subsetIDs: - subsetLab[posData.lab == subsetID] = subsetID prev_lab = posData.allData_li[posData.frame_i-1]['labels'] prev_rp = posData.allData_li[posData.frame_i-1]['regionprops'] - tracked_lab = self.trackFrame( + assignments = self.trackFrame( prev_lab, prev_rp, posData.lab, posData.rp, posData.IDs, - assign_unique_new_IDs=True - ) - doUpdateRp = False - for subsetID in subsetIDs: - subsetIDmask = posData.lab == subsetID - trackedID = tracked_lab[subsetIDmask][0] - if trackedID == subsetID: - continue - - is_manually_edited = False - for y, x, new_ID in posData.editID_info: - if new_ID == subsetID: + assign_unique_new_IDs=True, specific_IDs=subsetIDs, + dont_return_tracked_lab=True + ) + # I think assignments already avoids merging + assignments_new = dict() + for old_ID, new_ID in assignments.items(): + # get "old" id based on assignments + if old_ID == new_ID: + continue # nothing to do + + for y, x, editID in posData.editID_info: + if editID == old_ID or editID == new_ID: # Do not track because it was manually edited - break + continue - posData.lab[subsetIDmask] = tracked_lab[subsetIDmask] - doUpdateRp = True - - if not doUpdateRp: - return + + obj = posData.rp.get_obj_from_ID(old_ID) # pr is still old, so we need to get the old ID + posData.lab[obj.slice][obj.image] = new_ID + assignments_new[old_ID] = new_ID # old ID : new tracked ID - self.update_rp() + self.update_rp(assignments=assignments_new) def doSkipTracking(self, against_next: bool, enforce: bool): if self.isSnapshot: @@ -28760,13 +29253,13 @@ def tracking( storeUndo=False, prev_lab=None, prev_rp=None, return_lab=False, assign_unique_new_IDs=True, separateByLabel=True, wl_update=True, - IDs=None, against_next=False, + against_next=False, specific_IDs=None , return_assignments=False ): posData = self.data[self.pos_i] - + return_tuple = (None, None) if return_assignments and return_lab else None if self.doSkipTracking(against_next, enforce): self.setLostNewOldPrevIDs() - return + return return_tuple """Tracking starts here""" staturBarLabelText = self.statusBarLabel.text() @@ -28800,41 +29293,54 @@ def tracking( if posData.frame_i < self.get_last_tracked_i(): unique_ID = self.setBrushID(return_val=True) - tracked_lab = self.trackFrame( + tracked_lab, assignments = self.trackFrame( prev_lab, prev_rp, posData.lab, posData.rp, posData.IDs, - assign_unique_new_IDs=assign_unique_new_IDs, IDs=IDs, - unique_ID=unique_ID + assign_unique_new_IDs=assign_unique_new_IDs, + unique_ID=unique_ID, specific_IDs=specific_IDs, + return_assignments=True ) if DoManualEdit: # Correct tracking with manually changed IDs - rp = skimage.measure.regionprops(tracked_lab) - IDs = [obj.label for obj in rp] - self.manuallyEditTracking(tracked_lab, IDs) + tracked_lab, assignments = self.manuallyEditTracking(tracked_lab, assignments) if return_lab: QTimer.singleShot(50, partial( self.statusBarLabel.setText, staturBarLabelText )) + if return_assignments: + return tracked_lab, assignments return tracked_lab # Update labels, regionprops and determine new and lost IDs posData.lab = tracked_lab - self.update_rp(wl_update=wl_update, ) + self.update_rp(wl_update=wl_update, assignments=assignments) self.setAllTextAnnotations() QTimer.singleShot(50, partial( self.statusBarLabel.setText, staturBarLabelText )) + if return_assignments and return_lab: + return tracked_lab, assignments + elif return_assignments: + return assignments + elif return_lab: + return tracked_lab - def handleAdditionalInfoRealTimeTracker(self, prev_rp, *args): + def handleAdditionalInfoRealTimeTracker(self, prev_rp, add_info): + assignments = None if self._rtTrackerName == 'CellACDC_normal_division': - tracked_lost_IDs = args[0] + tracked_lost_IDs = add_info['mothers'] self.setTrackedLostCentroids(prev_rp, tracked_lost_IDs) + assignments = add_info['assignments'] elif self._rtTrackerName == 'CellACDC_2steps': - if args[0] is None: - return - posData = self.data[self.pos_i] - posData.acdcTracker2stepsAnnotInfo[posData.frame_i] = args[0] + assignments = add_info['assignments'] + if add_info['to_track_tracked_objs_2nd_step'] is not None: + posData = self.data[self.pos_i] + posData.acdcTracker2stepsAnnotInfo[posData.frame_i] = add_info['to_track_tracked_objs_2nd_step'] + elif self._rtTrackerName == 'Cell-ACDC': + assignments = add_info['assignments'] + + return assignments def keepOnlyNewIDAssignedObjsSecondStep(self, trackedID): posData = self.data[self.pos_i] @@ -28895,7 +29401,11 @@ def annotateAssignedObjsAcdcTrackerSecondStep(self): new_objs_1st_step, lost_objs_1st_step = annotInfo for lostObj, newObj in zip(lost_objs_1st_step, new_objs_1st_step): - allContours = self.getObjContours(lostObj, all_external=True) + allContours = self.getObjContours( + lostObj, + all_external=True, + include_internal=self.showAllContoursToggle.isChecked() + ) for objContours in allContours: isObjVisible = self.isObjVisible(newObj.bbox) if not isObjVisible: @@ -28930,12 +29440,24 @@ def setTrackedLostCentroids(self, prev_rp, tracked_lost_IDs): """ posData = self.data[self.pos_i] frame_i = posData.frame_i + prev_lab = posData.allData_li[frame_i-1]['labels'] for obj in prev_rp: if obj.label not in tracked_lost_IDs: continue - - int_centroid = tuple([int(val) for val in obj.centroid]) + if isinstance(prev_rp, regionprops.acdcRegionprops): + ID = obj.ID + centroid = prev_rp.get_centroid(ID, exact=True) + else: + centroid = obj.centroid + int_centroid = tuple([int(val) for val in centroid]) + # check if centroid has right ID + if prev_lab[int_centroid] != ID: + # get closest point with the right ID + coords = obj.coords + distances = np.sqrt(np.sum((coords - centroid) ** 2, axis=1)) + closest_idx = np.argmin(distances) + int_centroid = tuple([int(val) for val in coords[closest_idx]]) try: posData.tracked_lost_centroids[frame_i].add(int_centroid) except KeyError: @@ -28989,28 +29511,59 @@ def getTrackedLostIDs(self, prev_lab=None, IDs_in_frames=None, frame_i=None): posData.trackedLostIDs = trackedLostIDs return trackedLostIDs - - def manuallyEditTracking(self, tracked_lab, allIDs): + + def manuallyEditTracking(self, tracked_lab, assignments): posData = self.data[self.pos_i] infoToRemove = [] - # Correct tracking with manually changed IDs - maxID = max(allIDs, default=1) - for y, x, new_ID in posData.editID_info: - old_ID = tracked_lab[y, x] - if old_ID == 0 or old_ID == new_ID: - infoToRemove.append((y, x, new_ID)) + + if not assignments: + return tracked_lab, assignments + + # !!! RP is stale so we need to reverse search for the ID + reversed_assignments = ( + {tracked_id: stale_id for stale_id, tracked_id in assignments.items()} + if assignments else {} + ) + stale_ids = set(posData.rp.IDs) + + covered_edited_IDs = set() + for y, x, edited_ID in posData.editID_info: + new_ID = assignments.get(edited_ID, edited_ID) # ID in tracked lab + if new_ID in covered_edited_IDs: + # This ID has already been edited by sawpping for example continue - if new_ID in allIDs: - tempID = maxID+1 - tracked_lab[tracked_lab == old_ID] = tempID - tracked_lab[tracked_lab == new_ID] = old_ID - tracked_lab[tracked_lab == tempID] = new_ID + + if new_ID == 0 or new_ID == edited_ID: # edited ID is not tracked to a different ID + infoToRemove.append((y, x, edited_ID)) + continue + + old_RP_ID = reversed_assignments.get(edited_ID, edited_ID) # ID pre tracking + old_obj = posData.rp.get_obj_from_ID(old_RP_ID) # obj pre tracking + + if edited_ID in stale_ids: + # a swap has been made by the user between an old ID (old_RP_ID) and a new ID (edited_ID) + new_obj = posData.rp.get_obj_from_ID(edited_ID) + tracked_lab[old_obj.slice][old_obj.image] = edited_ID + tracked_lab[new_obj.slice][new_obj.image] = old_RP_ID + # update assignemnets + assignments[old_RP_ID] = edited_ID + assignments[edited_ID] = old_RP_ID + # add the two swapped IDs + + covered_edited_IDs.add(edited_ID) + covered_edited_IDs.add(old_RP_ID) + else: - tracked_lab[tracked_lab == old_ID] = new_ID - if new_ID > maxID: - maxID = new_ID + tracked_lab[old_obj.slice][old_obj.image] = edited_ID + + assignments[old_RP_ID] = edited_ID + + covered_edited_IDs.add(edited_ID) + for info in infoToRemove: posData.editID_info.remove(info) + + return tracked_lab, assignments def warnReinitLastSegmFrame(self): current_frame_n = self.navigateScrollBar.value() @@ -30243,8 +30796,12 @@ def initRealTimeTracker(self, force=False): rtTracker = aliases[rtTracker] if rtTracker == 'Cell-ACDC': + self._rtTrackerName = 'Cell-ACDC' + self.realTimeTracker_kwargs = None # This is hard coded return if rtTracker == 'YeaZ': + self._rtTrackerName = 'YeaZ' + self.realTimeTracker_kwargs = None # This is hard coded return if self.isRealTimeTrackerInitialized and not force: @@ -30262,6 +30819,8 @@ def initRealTimeTracker(self, force=False): self.realTimeTracker = realTimeTracker self.track_frame_params = track_frame_params + self.realTimeTracker_kwargs = inspect.signature( + self.realTimeTracker.track_frame).parameters self.logger.info(f'{rtTracker} tracker successfully initialized.') if 'image_channel_name' in self.track_frame_params: # Remove the channel name since it was already loaded in init_tracker @@ -30755,6 +31314,44 @@ def annotOptionClickedRight( self.setDrawAnnotComboboxTextRight(saveSettings=saveSettings) + def checkHandleTooManyNewItems(self): + posData = self.data[self.pos_i] + num_objects = len(posData.rp) + if num_objects < 1500: + return True + + out = _warnings.warnTooManyNewItems(self, num_objects, self) + cancel, switchToLowRes, deactivateAnnot = out + if cancel: + return False + + if switchToLowRes: + self.highLowResAction.setChecked(False) + self.changeTextResolution() + return True + + if deactivateAnnot: + self.annotCcaInfoCheckbox.blockSignals(True) + self.annotIDsCheckbox.blockSignals(True) + self.annotCcaInfoCheckbox.setChecked(False) + self.annotIDsCheckbox.setChecked(False) + self.annotCcaInfoCheckbox.blockSignals(False) + self.annotIDsCheckbox.blockSignals(False) + + self.annotCcaInfoCheckboxRight.blockSignals(True) + self.annotIDsCheckboxRight.blockSignals(True) + self.annotCcaInfoCheckboxRight.setChecked(False) + self.annotIDsCheckboxRight.setChecked(False) + self.annotCcaInfoCheckboxRight.blockSignals(False) + self.annotIDsCheckboxRight.blockSignals(False) + + self.textAnnot[0].setCcaAnnot(False) + self.textAnnot[0].setLabelAnnot(False) + self.textAnnot[1].setCcaAnnot(False) + self.textAnnot[1].setLabelAnnot(False) + return True + + def setAnnotOptionsCcaMode(self): self.prevAnnotOptions = self.storeCurrentAnnotOptions_ax1( return_value=True @@ -32010,7 +32607,9 @@ def getZoomIDs(self, viewRange=None): ) zoomLab = skimage.segmentation.clear_border(lab[zoomSlice]) - zoomRp = skimage.measure.regionprops(zoomLab) + zoomRp = regionprops.acdcRegionprops( + zoomLab, precache_centroids=False + ) zoomIDs = [obj.label for obj in zoomRp] return zoomIDs diff --git a/cellacdc/load.py b/cellacdc/load.py index 62a336464..5c6f634a8 100755 --- a/cellacdc/load.py +++ b/cellacdc/load.py @@ -40,7 +40,7 @@ from . import io from . import core from . import IMAGE_EXTENSIONS, VIDEO_EXTENSIONS - +from . import fonts from . import GUI_INSTALLED if GUI_INSTALLED: @@ -1306,6 +1306,7 @@ def __init__( self.loadLastEntriesMetadata() self.attempFixBasenameBug() self.non_aligned_ext = '.tif' + self.segmMetadata = None if filename_ext.endswith('aligned.npz'): for file in myutils.listdir(self.images_path): if file.endswith(f'{user_ch_name}.h5'): @@ -1646,7 +1647,13 @@ def countObjectsInSegmTimelapse(self, categories: set[str] | list[str]): for frame_i in range(len(self.segm_data)): lab = self.allData_li[frame_i]['labels'] if lab is not None: - IDsFrame = self.allData_li[frame_i]['IDs'] + if hasattr(self.allData_li[frame_i]['regionprops'], 'IDs'): + IDsFrame = self.allData_li[frame_i]['regionprops'].IDs + else: + IDsFrame = [ + obj.label + for obj in self.allData_li[frame_i]['regionprops'] + ] if uniqueIDsVisited is not None: uniqueIDsVisited.update(IDsFrame) @@ -1866,6 +1873,7 @@ def loadOtherFiles( new_endname='', labelBoolSegm=None, load_whitelistIDs=False, + load_segm_info_ini=False ): self.segmFound = False if load_segm_data else None self.acdc_df_found = False if load_acdc_df else None @@ -2069,6 +2077,9 @@ def loadOtherFiles( if load_whitelistIDs: self.loadWhitelist() + + if load_segm_info_ini: + self.readSegmMetadataIni() def checkAndFixZsliceSegmInfo(self): if not hasattr(self, 'segmInfo_df'): @@ -2322,14 +2333,14 @@ def fromTrackerToAcdcDf( rp = skimage.measure.regionprops(lab) for obj in rp: centroid = obj.centroid - yc, xc = obj.centroid[-2:] + yc, xc = centroid[-2:] acdc_df.at[(frame_i, obj.label), 'x_centroid'] = int(xc) acdc_df.at[(frame_i, obj.label), 'y_centroid'] = int(yc) if len(centroid) == 3: if 'z_centroid' not in acdc_df.columns: acdc_df['z_centroid'] = 0 - zc = obj.centroid[0] + zc = centroid[0] acdc_df.at[(frame_i, obj.label), 'z_centroid'] = int(zc) if not save: @@ -3059,6 +3070,7 @@ def buildPaths(self): self.raw_postproc_segm_path = f'{base_path}segm_raw_postproc' self.post_proc_mot_metrics = f'{base_path}post_proc_mot_metrics' self.segm_hyperparams_ini_path = f'{base_path}segm_hyperparams.ini' + self.segm_metadata_ini_path = f'{base_path}segm_metadata_data.ini' self.custom_annot_json_path = f'{base_path}custom_annot_params.json' self.custom_combine_metrics_path = ( f'{base_path}custom_combine_metrics.ini' @@ -3082,6 +3094,7 @@ def get_tracker_export_path(self, trackerName, ext): def setBlankSegmData(self, SizeT, SizeZ, SizeY, SizeX): if not hasattr(self, 'img_data'): self.segm_data = None + self.single_timepoint_size = None return Y, X = self.img_data.shape[-2:] @@ -3093,7 +3106,16 @@ def setBlankSegmData(self, SizeT, SizeZ, SizeY, SizeX): elif SizeT > 1: self.segm_data = np.zeros((SizeT, Y, X), int) else: - self.segm_data = np.zeros((Y, X), int) + self.segm_data = np.zeros((Y, X), int) + + def getSingleTimepointSegmSize(self): + if hasattr(self, 'single_timepoint_size') and self.single_timepoint_size is not None: + return self.single_timepoint_size + if self.SizeT > 1: + self.single_timepoint_size = np.prod(self.segm_data.shape[1:]) + else: # not sure if time axis is present but would be 1 anyways + self.single_timepoint_size = np.prod(self.segm_data.shape) + return self.single_timepoint_size def loadAllImgPaths(self): tif_paths = [] @@ -3185,7 +3207,7 @@ def askInputMetadata( self.SizeT, self.SizeZ, self.TimeIncrement, self.PhysicalSizeZ, self.PhysicalSizeY, self.PhysicalSizeX, ask_SizeT, ask_TimeIncrement, ask_PhysicalSizes, - parent=self.parent, font=apps.font, imgDataShape=self.img_data_shape, + parent=self.parent, font=fonts.font, imgDataShape=self.img_data_shape, posData=self, singlePos=singlePos, askSegm3D=askSegm3D, additionalValues=self._additionalMetadataValues, forceEnableAskSegm3D=forceEnableAskSegm3D, @@ -3421,7 +3443,15 @@ def loadWhitelist(self): self.whitelist = whitelist.Whitelist( total_frames=self.SizeT, ) - whitelist_path = self.segm_npz_path.replace('.npz', '_whitelistIDs.json') + whitelist_path_legacy = self.segm_npz_path.replace( + '.npz', '_whitelistIDs.json') + segm_filename = os.path.basename(self.segm_npz_path).replace('.npz', '') + segm_add_data_folder = os.path.join(self.images_path, segm_filename) + os.makedirs(segm_add_data_folder, exist_ok=True) + whitelist_path = os.path.join(segm_add_data_folder, 'whitelistIDs.json') + if os.path.exists(whitelist_path_legacy): + # move to new path + shutil.move(whitelist_path_legacy, whitelist_path) new_centroids_path = self.segm_npz_path.replace('.npz', '_new_centroids.json') success = self.whitelist.load( whitelist_path, new_centroids_path, self.segm_data, self.allData_li, @@ -3432,7 +3462,96 @@ def loadWhitelist(self): if not success: self.whitelist = None - + def readSegmMetadataIni(self): + if not os.path.exists(self.segm_metadata_ini_path): + return None + + cp = config.ConfigParser() + cp.read(self.segm_metadata_ini_path) + # one entry for each segmentation file + self.segmMetadata = {} + for segm_file in cp.sections(): + sizeX = cp.getint(segm_file, 'sizeX', fallback=None) + sizeY = cp.getint(segm_file, 'sizeY', fallback=None) + sizeT = cp.getint(segm_file, 'SizeT', fallback=None) + sizeZ = cp.getint(segm_file, 'SizeZ', fallback=None) + is_3D = sizeZ > 1 if sizeZ is not None else False + last_modified_date = cp.get(segm_file, 'last_modified_date', fallback=None) + acdc_df_segm = cp.get(segm_file, 'acdc_df_segm', fallback=None) + acdc_df_save_date = cp.get(segm_file, 'acdc_df_save_date', fallback=None) + self.segmMetadata[segm_file] = { + 'SizeT': sizeT, + 'SizeZ': sizeZ, + 'is_3D': is_3D, + 'last_modified_date': last_modified_date, + 'acdc_df_segm': acdc_df_segm, + 'acdc_df_save_date': acdc_df_save_date, + 'sizeX': sizeX, + 'sizeY': sizeY, + } + + def saveSegmMetadataIni(self): + # need to be called in more locations, will be full yimplemented in workflow gui + cp = config.ConfigParser() + for segm_file, metadata in self.segmMetadata.items(): + cp[segm_file] = {} + cp[segm_file]['SizeT'] = str(metadata.get('SizeT', '')) + cp[segm_file]['SizeZ'] = str(metadata.get('SizeZ', '')) + cp[segm_file]['last_modified_date'] = str(metadata.get('last_modified_date', '')) + cp[segm_file]['acdc_df_segm'] = str(metadata.get('acdc_df_segm', '')) + cp[segm_file]['sizeX'] = str(metadata.get('sizeX', '')) + cp[segm_file]['sizeY'] = str(metadata.get('sizeY', '')) + cp[segm_file]['acdc_df_save_date'] = str( + metadata.get('acdc_df_save_date', '') + ) + + with open(self.segm_metadata_ini_path, 'w') as configfile: + cp.write(configfile) + + def updateSegmMetadata(self, segm_file=None, SizeT=None, SizeZ=None, + acdc_df_segm=None, last_modified_date=None, + sizeY=None, sizeX=None, all=False, acdc_df_save_date=None): + if segm_file is None: + segm_file = os.path.basename(self.segm_npz_path) + + if self.segmMetadata is None: + self.segmMetadata = {} + segm_metadata = self.segmMetadata.get(segm_file, {}) + if SizeT is not None or all: + if SizeT is True or SizeT is None: + SizeT = self.SizeT + segm_metadata['SizeT'] = SizeT + if SizeZ is not None or all: + if SizeZ is True or SizeZ is None: + SizeZ = self.SizeZ if self.isSegm3D else 1 + segm_metadata['SizeZ'] = SizeZ + segm_metadata['is_3D'] = SizeZ > 1 + if acdc_df_segm is not None or all: + if acdc_df_segm is True or acdc_df_segm is None: + acdc_df_segm = os.path.basename(self.acdc_output_csv_path) # for future if we allow multpiple outputs + # clear other segm metadata entries with acdc_df info to avoid confusion + for info in self.segmMetadata.values(): + if info.get('acdc_df_segm', '') == acdc_df_segm: + info['acdc_df_segm'] = None + segm_metadata['acdc_df_segm'] = acdc_df_segm + if last_modified_date is not None or all: + if last_modified_date is True or last_modified_date is None: # explicitly in this cane set curr datetime + last_modified_date = datetime.now() + segm_metadata['last_modified_date'] = last_modified_date + if sizeY is not None or all: + if sizeY is True or sizeY is None: + sizeY = self.SizeY + segm_metadata['sizeY'] = sizeY + if sizeX is not None or all: + if sizeX is True or sizeX is None: + sizeX = self.SizeX + segm_metadata['sizeX'] = sizeX + if acdc_df_save_date is not None or all: + if acdc_df_save_date is True or acdc_df_save_date is None: + acdc_df_save_date = datetime.now() + segm_metadata['acdc_df_save_date'] = acdc_df_save_date + self.segmMetadata[segm_file] = segm_metadata + class select_exp_folder: def __init__(self): self.exp_path = None diff --git a/cellacdc/myutils.py b/cellacdc/myutils.py index f2695bcd3..51ae24e98 100644 --- a/cellacdc/myutils.py +++ b/cellacdc/myutils.py @@ -55,6 +55,7 @@ from . import urls from . import qrc_resources_path from . import settings_folderpath +from . import regionprops from .models._cellpose_base import min_target_versions_cp if GUI_INSTALLED: @@ -344,6 +345,7 @@ def __init__( level=logging.DEBUG ): super().__init__(f'{name}-{module}', level=level) + self.propagate = False # prevent UnicodeEncodeError via root StreamHandler self._stdout = sys.stdout self._stderr = StdErr(logger=self) sys.stderr = self._stderr @@ -367,8 +369,13 @@ def write(self, text, log_to_file=True, write_to_stdout=True): log_to_file : bool, optional If True, call `info` method with `text`. Default is True """ - if write_to_stdout: - self._stdout.write(text) + if write_to_stdout: + try: + self._stdout.write(text) + except UnicodeEncodeError: + self._stdout.write(text.encode( + self._stdout.encoding, errors='replace' + ).decode(self._stdout.encoding)) if not log_to_file: return @@ -378,8 +385,11 @@ def write(self, text, log_to_file=True, write_to_stdout=True): if not text: return - - self.debug(text) + + try: + self.debug(text) + except UnicodeEncodeError: + self.debug(text.encode('ascii', errors='replace').decode('ascii')) def close(self): for handler in self.handlers: @@ -614,7 +624,7 @@ def setupLogger(module='base', logs_path=None, caller='Cell-ACDC'): log_filename = f'{date_time}_{module}_{id}_stdout.log' log_path = os.path.join(logs_path, log_filename) - output_file_handler = logging.FileHandler(log_path, mode='w') + output_file_handler = logging.FileHandler(log_path, mode='w', encoding='utf-8') # Format your logs (optional) formatter = logging.Formatter( @@ -1033,22 +1043,6 @@ def showInExplorer(path): else: os.startfile(path) -def exec_time(func): - @wraps(func) - def inner_function(self, *args, **kwargs): - t0 = time.perf_counter() - if func.__code__.co_argcount==1 and func.__defaults__ is None: - result = func(self) - elif func.__code__.co_argcount>1 and func.__defaults__ is None: - result = func(self, *args) - else: - result = func(self, *args, **kwargs) - t1 = time.perf_counter() - s = f'{func.__name__} execution time = {(t1-t0)*1000:.3f} ms' - printl(s, is_decorator=True) - return result - return inner_function - def setRetainSizePolicy(widget, retain=True): sp = widget.sizePolicy() sp.setRetainSizeWhenHidden(retain) @@ -1111,22 +1105,61 @@ def get_chname_from_basename(filename, basename, remove_ext=True): chName = chName[:aligned_idx] return chName +def _edge_ids_2d(lab): + border_labels = np.r_[ + lab[0, :], + lab[-1, :], + lab[:, 0], + lab[:, -1], + ] + return np.unique(border_labels[border_labels != 0]) + +def _edge_ids_3d(lab): + face_labels = np.r_[ + lab[ 0, :, :].ravel(), # z min + lab[-1, :, :].ravel(), # z max + lab[:, 0, :].ravel(), # y min + lab[:, -1, :].ravel(), # y max + lab[:, :, 0].ravel(), # x min + lab[:, :, -1].ravel(), # x max + ] + ids = np.unique(face_labels) + return ids[ids != 0] + +def get_edge_ids(lab): + if lab.ndim == 2: + return _edge_ids_2d(lab) + elif lab.ndim == 3: + return _edge_ids_3d(lab) + else: + raise ValueError('Label array must be either 2D or 3D.') + +def clear_border(lab, return_edge_ids=False): + # probably faster than skimage since it avoids relabeling... + # assumes continous unique IDs, which we have. Modifies inplace! + edge_ids = get_edge_ids(lab) + lab[np.isin(lab, edge_ids)] = 0 + if return_edge_ids: + return edge_ids + def getBaseAcdcDf(rp): zeros_list = [0]*len(rp) nones_list = [None]*len(rp) minus1_list = [-1]*len(rp) - IDs = [] - xx_centroid = [] - yy_centroid = [] - zz_centroid = [] - for obj in rp: - xc, yc = obj.centroid[-2:] - IDs.append(obj.label) - xx_centroid.append(xc) - yy_centroid.append(yc) - if len(obj.centroid) == 3: - zc = obj.centroid[0] - zz_centroid.append(zc) + IDs = [0]*len(rp) + xx_centroid = [0]*len(rp) + yy_centroid = [0]*len(rp) + zz_centroid = [0]*len(rp) + + for i, obj in enumerate(rp): + centroid = obj.centroid + xc, yc = centroid[-2:] + IDs[i] = obj.label + xx_centroid[i] = xc + yy_centroid[i] = yc + if len(centroid) == 3: + zc = centroid[0] + zz_centroid[i] = zc df = pd.DataFrame( { @@ -1138,7 +1171,7 @@ def getBaseAcdcDf(rp): 'was_manually_edited': minus1_list } ).set_index('Cell_ID') - if zz_centroid: + if len(centroid) == 3: df['z_centroid'] = zz_centroid return df @@ -1705,7 +1738,7 @@ def download_java(): def get_model_path(model_name, create_temp_dir=True): if model_name == 'Automatic thresholding': - model_name == 'thresholding' + model_name = 'thresholding' model_info_path = os.path.join(cellacdc_path, 'models', model_name, 'model') @@ -2337,7 +2370,7 @@ def lab2d_to_rois(ImagejRoi, lab2D, ndigits, t=None, z=None): rp = skimage.measure.regionprops(lab2D) rois = [] for obj in rp: - cont = core.get_obj_contours(obj) + cont = core.get_obj_contours(obj=obj) yc, xc = obj.centroid x_str = str((int(xc))).zfill(ndigits) y_str = str((int(yc))).zfill(ndigits) @@ -4113,13 +4146,31 @@ def init_tracker( return tracker, track_params def import_segment_module(model_name): + original_model_name = model_name + if model_name == 'Automatic thresholding': + model_name = 'thresholding' + try: acdcSegment = import_module(f'cellacdc.models.{model_name}.acdcSegment') except ModuleNotFoundError as e: + # Do not mask missing dependencies imported by the module itself. + expected_missing_module = f'cellacdc.models.{model_name}' + if e.name != expected_missing_module: + raise + # Check if custom model cp = config.ConfigParser() cp.read(models_list_file_path) - model_path = cp[model_name]['path'] + model_key = None + for key in (original_model_name, model_name): + if key in cp: + model_key = key + break + + if model_key is None: + raise + + model_path = cp[model_key]['path'] spec = importlib.util.spec_from_file_location('acdcSegment', model_path) acdcSegment = importlib.util.module_from_spec(spec) sys.modules['acdcSegment'] = acdcSegment @@ -5121,7 +5172,6 @@ def get_empty_stored_data_dict(): 'delROIs_info': { 'rois': [], 'delMasks': [], 'delIDsROI': [], 'state': [] }, - 'IDs': [], 'manually_edited_lab': {'lab': {}, 'zoom_slice': None} } @@ -5503,7 +5553,7 @@ def find_distances_ID(rps, point=None, ID=None): if ID is not None and point is None: try: - point = [rp.centroid for rp in rps if rp.label == ID][0] + point = rps.get_centroid(ID) except IndexError: raise ValueError(f'ID {ID} not found in regionprops (list of cells).') @@ -5515,7 +5565,7 @@ def find_distances_ID(rps, point=None, ID=None): point = point[::-1] # rp are in (y, x) format (or (z, y, x) for 3D data) so I need to reverse order point = np.array([point]) - centroids = np.array([rp.centroid for rp in rps]) + centroids = np.array([rps.get_centroid(ID) for ID in rps.IDs]) diff = point[:, np.newaxis] - centroids dist_matrix = np.linalg.norm(diff, axis=2) return dist_matrix @@ -5552,7 +5602,7 @@ def sort_IDs_dist(rps, point=None, ID=None): """ if ID is not None and point is None: try: - point = [rp.centroid for rp in rps if rp.label == ID][0] + point = rps.get_centroid(ID) except IndexError: raise ValueError(f'ID {ID} not found in regionprops (list of cells).') @@ -5563,7 +5613,7 @@ def sort_IDs_dist(rps, point=None, ID=None): raise ValueError('Only one of ID or point must be provided.') - IDs = [rp.label for rp in rps] + IDs = rps.IDs if len(IDs) == 0: return [] elif len(IDs) == 1: diff --git a/cellacdc/plot.py b/cellacdc/plot.py index 3d33b3c23..34ffe812d 100644 --- a/cellacdc/plot.py +++ b/cellacdc/plot.py @@ -15,6 +15,7 @@ from mpl_toolkits.axes_grid1 import make_axes_locatable import seaborn as sns +from . import debugutils from tqdm import tqdm from . import GUI_INSTALLED @@ -25,7 +26,7 @@ from . import printl from . import _core, error_below, error_close -from . import _run, core, myutils +from . import _run, core, myutils, regionprops as acdc_regionprops def matplotlib_cmap_to_lut( cmap: Union[Iterable, matplotlib.colors.Colormap, str], @@ -768,19 +769,22 @@ def plt_contours( clear_borders=True, obj_contours_kwargs=None ): if rp is None: - rp = skimage.measure.regionprops(lab) + rp = acdc_regionprops.acdcRegionprops(lab, precache_centroids=False) if plot_kwargs is None: plot_kwargs = {} if obj_contours_kwargs is None: obj_contours_kwargs = {} + elif 'include_internal' in obj_contours_kwargs: + obj_contours_kwargs = obj_contours_kwargs.copy() + obj_contours_kwargs['all'] = obj_contours_kwargs.pop('include_internal') for obj in rp: if only_IDs is not None and obj.label not in only_IDs: continue - contours = core.get_obj_contours(obj, **obj_contours_kwargs) + contours = core.get_obj_contours(obj=obj, **obj_contours_kwargs) if not isinstance(contours, list): contours = [contours] diff --git a/cellacdc/precompiled/precompiled_functions.cp310-win_amd64.pyd b/cellacdc/precompiled/precompiled_functions.cp310-win_amd64.pyd new file mode 100644 index 000000000..6728157cd Binary files /dev/null and b/cellacdc/precompiled/precompiled_functions.cp310-win_amd64.pyd differ diff --git a/cellacdc/precompiled/precompiled_functions.cp311-win_amd64.pyd b/cellacdc/precompiled/precompiled_functions.cp311-win_amd64.pyd new file mode 100644 index 000000000..951ad121d Binary files /dev/null and b/cellacdc/precompiled/precompiled_functions.cp311-win_amd64.pyd differ diff --git a/cellacdc/precompiled/precompiled_functions.cp312-win_amd64.pyd b/cellacdc/precompiled/precompiled_functions.cp312-win_amd64.pyd new file mode 100644 index 000000000..9d5d0547f Binary files /dev/null and b/cellacdc/precompiled/precompiled_functions.cp312-win_amd64.pyd differ diff --git a/cellacdc/precompiled/precompiled_functions.cp313-win_amd64.pyd b/cellacdc/precompiled/precompiled_functions.cp313-win_amd64.pyd new file mode 100644 index 000000000..c0ad5b52b Binary files /dev/null and b/cellacdc/precompiled/precompiled_functions.cp313-win_amd64.pyd differ diff --git a/cellacdc/precompiled/precompiled_functions.cpython-310-darwin.so b/cellacdc/precompiled/precompiled_functions.cpython-310-darwin.so new file mode 100644 index 000000000..0b70d8007 Binary files /dev/null and b/cellacdc/precompiled/precompiled_functions.cpython-310-darwin.so differ diff --git a/cellacdc/precompiled/precompiled_functions.cpython-310-x86_64-linux-gnu.so b/cellacdc/precompiled/precompiled_functions.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 000000000..e84696138 Binary files /dev/null and b/cellacdc/precompiled/precompiled_functions.cpython-310-x86_64-linux-gnu.so differ diff --git a/cellacdc/precompiled/precompiled_functions.cpython-311-darwin.so b/cellacdc/precompiled/precompiled_functions.cpython-311-darwin.so new file mode 100644 index 000000000..9e46a0cb0 Binary files /dev/null and b/cellacdc/precompiled/precompiled_functions.cpython-311-darwin.so differ diff --git a/cellacdc/precompiled/precompiled_functions.cpython-311-x86_64-linux-gnu.so b/cellacdc/precompiled/precompiled_functions.cpython-311-x86_64-linux-gnu.so new file mode 100644 index 000000000..96b91d778 Binary files /dev/null and b/cellacdc/precompiled/precompiled_functions.cpython-311-x86_64-linux-gnu.so differ diff --git a/cellacdc/precompiled/precompiled_functions.cpython-312-darwin.so b/cellacdc/precompiled/precompiled_functions.cpython-312-darwin.so new file mode 100644 index 000000000..2a82f1217 Binary files /dev/null and b/cellacdc/precompiled/precompiled_functions.cpython-312-darwin.so differ diff --git a/cellacdc/precompiled/precompiled_functions.cpython-312-x86_64-linux-gnu.so b/cellacdc/precompiled/precompiled_functions.cpython-312-x86_64-linux-gnu.so new file mode 100644 index 000000000..14c6dd994 Binary files /dev/null and b/cellacdc/precompiled/precompiled_functions.cpython-312-x86_64-linux-gnu.so differ diff --git a/cellacdc/precompiled/precompiled_functions.cpython-313-darwin.so b/cellacdc/precompiled/precompiled_functions.cpython-313-darwin.so new file mode 100644 index 000000000..d715591ef Binary files /dev/null and b/cellacdc/precompiled/precompiled_functions.cpython-313-darwin.so differ diff --git a/cellacdc/precompiled/precompiled_functions.cpython-313-x86_64-linux-gnu.so b/cellacdc/precompiled/precompiled_functions.cpython-313-x86_64-linux-gnu.so new file mode 100644 index 000000000..f19d4d7f0 Binary files /dev/null and b/cellacdc/precompiled/precompiled_functions.cpython-313-x86_64-linux-gnu.so differ diff --git a/cellacdc/precompiled_functions.pyx b/cellacdc/precompiled_functions.pyx new file mode 100644 index 000000000..f95298411 --- /dev/null +++ b/cellacdc/precompiled_functions.pyx @@ -0,0 +1,493 @@ +# precompiled_functions.pyx +# cython: boundscheck=False, wraparound=False, cdivision=True +# rand change to trigger gh actions: 2 +import numpy as np +cimport numpy as np +from libc.limits cimport UINT_MAX + +def find_all_objects_2D(np.uint32_t[:, :] label_img): + cdef Py_ssize_t n_rows = label_img.shape[0] + cdef Py_ssize_t n_cols = label_img.shape[1] + cdef Py_ssize_t i, j + cdef unsigned int label, max_label = 0 + cdef unsigned int capacity = 300, new_cap + + cdef np.ndarray[np.uint32_t, ndim=1] _rs = np.full(capacity, UINT_MAX, dtype=np.uint32) + cdef np.ndarray[np.uint32_t, ndim=1] _re = np.zeros(capacity, dtype=np.uint32) + cdef np.ndarray[np.uint32_t, ndim=1] _cs = np.full(capacity, UINT_MAX, dtype=np.uint32) + cdef np.ndarray[np.uint32_t, ndim=1] _ce = np.zeros(capacity, dtype=np.uint32) + + cdef unsigned int[:] rs = _rs, re = _re, cs = _cs, ce = _ce + + # Single pass: compute bounding boxes, growing arrays in 300-label steps if needed + for i in range(n_rows): + for j in range(n_cols): + label = label_img[i, j] + if label == 0: + continue + if label >= capacity: + new_cap = ((label // 300) + 1) * 300 + _rs = np.concatenate((_rs, np.full(new_cap - capacity, UINT_MAX, dtype=np.uint32))) + _re = np.concatenate((_re, np.zeros(new_cap - capacity, dtype=np.uint32))) + _cs = np.concatenate((_cs, np.full(new_cap - capacity, UINT_MAX, dtype=np.uint32))) + _ce = np.concatenate((_ce, np.zeros(new_cap - capacity, dtype=np.uint32))) + rs = _rs; re = _re; cs = _cs; ce = _ce + capacity = new_cap + if label > max_label: + max_label = label + if i < rs[label]: rs[label] = i + if i + 1 > re[label]: re[label] = (i + 1) + if j < cs[label]: cs[label] = j + if j + 1 > ce[label]: ce[label] = (j + 1) + + if max_label == 0: + return [], [] + + # Collect present labels into compact numpy arrays (avoids per-label tuple allocation) + cdef unsigned int n_labels = 0 + for lbl in range(1, max_label + 1): + if re[lbl] != 0: + n_labels += 1 + + cdef np.ndarray[np.uint32_t, ndim=1] out_labels = np.empty(n_labels, dtype=np.uint32) + cdef np.ndarray[np.uint32_t, ndim=2] out_bboxes = np.empty((n_labels, 4), dtype=np.uint32) + cdef unsigned int idx = 0 + for lbl in range(1, max_label + 1): + if re[lbl] != 0: + out_labels[idx] = lbl + out_bboxes[idx, 0] = rs[lbl] + out_bboxes[idx, 1] = re[lbl] + out_bboxes[idx, 2] = cs[lbl] + out_bboxes[idx, 3] = ce[lbl] + idx += 1 + return out_labels, out_bboxes + +def find_all_objects_3D(np.uint32_t[:, :, :] label_img): + cdef Py_ssize_t n_z = label_img.shape[0] + cdef Py_ssize_t n_rows = label_img.shape[1] + cdef Py_ssize_t n_cols = label_img.shape[2] + cdef Py_ssize_t i, j, k + cdef unsigned int label, max_label = 0 + cdef unsigned int capacity = 300, new_cap + + cdef np.ndarray[np.uint32_t, ndim=1] _zs = np.full(capacity, UINT_MAX, dtype=np.uint32) + cdef np.ndarray[np.uint32_t, ndim=1] _ze = np.zeros(capacity, dtype=np.uint32) + cdef np.ndarray[np.uint32_t, ndim=1] _rs = np.full(capacity, UINT_MAX, dtype=np.uint32) + cdef np.ndarray[np.uint32_t, ndim=1] _re = np.zeros(capacity, dtype=np.uint32) + cdef np.ndarray[np.uint32_t, ndim=1] _cs = np.full(capacity, UINT_MAX, dtype=np.uint32) + cdef np.ndarray[np.uint32_t, ndim=1] _ce = np.zeros(capacity, dtype=np.uint32) + + cdef unsigned int[:] zs = _zs, ze = _ze, rs = _rs, re = _re, cs = _cs, ce = _ce + + # Single pass: compute bounding boxes, growing arrays in 300-label steps if needed + for i in range(n_z): + for j in range(n_rows): + for k in range(n_cols): + label = label_img[i, j, k] + if label == 0: + continue + if label >= capacity: + new_cap = ((label // 300) + 1) * 300 + _zs = np.concatenate((_zs, np.full(new_cap - capacity, UINT_MAX, dtype=np.uint32))) + _ze = np.concatenate((_ze, np.zeros(new_cap - capacity, dtype=np.uint32))) + _rs = np.concatenate((_rs, np.full(new_cap - capacity, UINT_MAX, dtype=np.uint32))) + _re = np.concatenate((_re, np.zeros(new_cap - capacity, dtype=np.uint32))) + _cs = np.concatenate((_cs, np.full(new_cap - capacity, UINT_MAX, dtype=np.uint32))) + _ce = np.concatenate((_ce, np.zeros(new_cap - capacity, dtype=np.uint32))) + zs = _zs; ze = _ze; rs = _rs; re = _re; cs = _cs; ce = _ce + capacity = new_cap + if label > max_label: + max_label = label + if i < zs[label]: zs[label] = i + if i + 1 > ze[label]: ze[label] = (i + 1) + if j < rs[label]: rs[label] = j + if j + 1 > re[label]: re[label] = (j + 1) + if k < cs[label]: cs[label] = k + if k + 1 > ce[label]: ce[label] = (k + 1) + + if max_label == 0: + return [], [] + + # Collect present labels into compact numpy arrays (avoids per-label tuple allocation) + cdef unsigned int n_labels = 0 + for lbl in range(1, max_label + 1): + if ze[lbl] != 0: + n_labels += 1 + + cdef np.ndarray[np.uint32_t, ndim=1] out_labels = np.empty(n_labels, dtype=np.uint32) + cdef np.ndarray[np.uint32_t, ndim=2] out_bboxes = np.empty((n_labels, 6), dtype=np.uint32) + cdef unsigned int idx = 0 + for lbl in range(1, max_label + 1): + if ze[lbl] != 0: + out_labels[idx] = lbl + out_bboxes[idx, 0] = zs[lbl] + out_bboxes[idx, 1] = ze[lbl] + out_bboxes[idx, 2] = rs[lbl] + out_bboxes[idx, 3] = re[lbl] + out_bboxes[idx, 4] = cs[lbl] + out_bboxes[idx, 5] = ce[lbl] + idx += 1 + return out_labels, out_bboxes + +def most_common_projection_3D(np.uint32_t[:, :, :] lab, int axis): + """Most-common-value projection for a 3-D label image along `axis`. + + Tie-break matches np.unique(..., return_counts=True) + np.argmax(counts), + i.e. the smallest label wins when counts are equal. + """ + if axis < 0 or axis > 2: + raise ValueError(f'axis must be 0, 1, or 2. Got {axis}.') + + cdef Py_ssize_t z = lab.shape[0] + cdef Py_ssize_t y = lab.shape[1] + cdef Py_ssize_t x = lab.shape[2] + cdef Py_ssize_t i, j, a, b, depth + cdef unsigned int v, vv + cdef unsigned int best_label, best_count, curr_count + cdef bint seen + cdef np.uint32_t[:, :] out_view + + if axis == 0: + depth = z + out = np.empty((y, x), dtype=np.uint32) + out_view = out + for i in range(y): + for j in range(x): + best_count = 0 + best_label = 0 + for a in range(depth): + v = lab[a, i, j] + if v == 0: + continue + seen = False + for b in range(a): + if lab[b, i, j] == v: + seen = True + break + if seen: + continue + + # Count all remaining occurrences of this label along the full axis. + curr_count = 1 + for b in range(a + 1, depth): + if lab[b, i, j] == v: + curr_count += 1 + + if curr_count > best_count or (curr_count == best_count and v < best_label): + best_count = curr_count + best_label = v + + out_view[i, j] = best_label + return out + + if axis == 1: + depth = y + out = np.empty((z, x), dtype=np.uint32) + out_view = out + for i in range(z): + for j in range(x): + best_count = 0 + best_label = 0 + for a in range(depth): + v = lab[i, a, j] + if v == 0: + continue + seen = False + for b in range(a): + if lab[i, b, j] == v: + seen = True + break + if seen: + continue + + curr_count = 1 + for b in range(a + 1, depth): + if lab[i, b, j] == v: + curr_count += 1 + + if curr_count > best_count or (curr_count == best_count and v < best_label): + best_count = curr_count + best_label = v + + out_view[i, j] = best_label + return out + + depth = x + out = np.empty((z, y), dtype=np.uint32) + out_view = out + for i in range(z): + for j in range(y): + best_count = 0 + best_label = 0 + for a in range(depth): + v = lab[i, j, a] + if v == 0: + continue + seen = False + for b in range(a): + vv = lab[i, j, b] + if vv == v: + seen = True + break + if seen: + continue + + curr_count = 1 + for b in range(a + 1, depth): + vv = lab[i, j, b] + if vv == v: + curr_count += 1 + + if curr_count > best_count or (curr_count == best_count and v < best_label): + best_count = curr_count + best_label = v + + out_view[i, j] = best_label + return out + +def object_projections_and_size_3D( + np.uint32_t[:, :, :] cutout, + unsigned int obj_id, +): + """Return binary XY/XZ/YZ projections and voxel count for specified object in a cutout.""" + cdef Py_ssize_t z = cutout.shape[0] + cdef Py_ssize_t y = cutout.shape[1] + cdef Py_ssize_t x = cutout.shape[2] + cdef Py_ssize_t i, j, k + cdef unsigned int size = 0 + + cdef np.ndarray[np.uint8_t, ndim=2] proj_z = np.zeros((y, x), dtype=np.uint8) + cdef np.ndarray[np.uint8_t, ndim=2] proj_y = np.zeros((z, x), dtype=np.uint8) + cdef np.ndarray[np.uint8_t, ndim=2] proj_x = np.zeros((z, y), dtype=np.uint8) + cdef np.uint8_t[:, :] proj_z_view = proj_z + cdef np.uint8_t[:, :] proj_y_view = proj_y + cdef np.uint8_t[:, :] proj_x_view = proj_x + + for i in range(z): + for j in range(y): + for k in range(x): + if cutout[i, j, k] != obj_id: + continue + size += 1 + proj_z_view[j, k] = 1 + proj_y_view[i, k] = 1 + proj_x_view[i, j] = 1 + + return proj_z, proj_y, proj_x, size + +def object_projection_and_size_3D( + np.uint32_t[:, :, :] cutout, + unsigned int obj_id, + int axis, +): + """Return one binary projection and voxel count for one object in a 3-D cutout. + + axis=0 -> XY projection (collapse z) + axis=1 -> XZ projection (collapse y) + axis=2 -> YZ projection (collapse x) + """ + if axis < 0 or axis > 2: + raise ValueError(f'axis must be 0, 1, or 2. Got {axis}.') + + cdef Py_ssize_t z = cutout.shape[0] + cdef Py_ssize_t y = cutout.shape[1] + cdef Py_ssize_t x = cutout.shape[2] + cdef Py_ssize_t i, j, k + cdef unsigned int size = 0 + + cdef np.ndarray[np.uint8_t, ndim=2] proj + cdef np.uint8_t[:, :] proj_view + + if axis == 0: + proj = np.zeros((y, x), dtype=np.uint8) + proj_view = proj + for i in range(z): + for j in range(y): + for k in range(x): + if cutout[i, j, k] != obj_id: + continue + size += 1 + proj_view[j, k] = 1 + return proj, size + + if axis == 1: + proj = np.zeros((z, x), dtype=np.uint8) + proj_view = proj + for i in range(z): + for j in range(y): + for k in range(x): + if cutout[i, j, k] != obj_id: + continue + size += 1 + proj_view[i, k] = 1 + return proj, size + + proj = np.zeros((z, y), dtype=np.uint8) + proj_view = proj + for i in range(z): + for j in range(y): + for k in range(x): + if cutout[i, j, k] != obj_id: + continue + size += 1 + proj_view[i, j] = 1 + + return proj, size + +def calc_IoA_matrix_2D( + np.uint32_t[:, :] lab, + np.uint32_t[:, :] prev_lab, + np.uint32_t[:] curr_IDs, + np.uint32_t[:] prev_IDs, + np.uint32_t[:] prev_areas, + np.uint32_t[:] curr_areas, + bint use_union, +): + """Single-pass IoA matrix between two 2-D label images. + + Parameters + ---------- + lab, prev_lab : (Y, X) uint32 label images for current and previous frame. + curr_IDs : 1-D array of current object labels (row order of output). + prev_IDs : 1-D array of previous object labels (col order of output). + prev_areas : pixel area of each entry in prev_IDs. + curr_areas : pixel area of each entry in curr_IDs (only used when use_union=True). + use_union : if False, denominator is area_prev; if True, denominator is union. + + Returns + ------- + IoA_matrix : (n_curr, n_prev) float64 array. + """ + cdef Py_ssize_t n_rows = lab.shape[0] + cdef Py_ssize_t n_cols = lab.shape[1] + cdef Py_ssize_t n_curr = curr_IDs.shape[0] + cdef Py_ssize_t n_prev = prev_IDs.shape[0] + cdef Py_ssize_t i, j, ci, pi + cdef unsigned int c, p, max_curr_label = 0, max_prev_label = 0 + cdef int ci_val, pi_val + + for i in range(n_curr): + if curr_IDs[i] > max_curr_label: + max_curr_label = curr_IDs[i] + for i in range(n_prev): + if prev_IDs[i] > max_prev_label: + max_prev_label = prev_IDs[i] + + # label -> matrix-index lookup; -1 means "not in the tracked set" + cdef np.ndarray[np.int32_t, ndim=1] _curr_idx = np.full(max_curr_label + 1, -1, dtype=np.int32) + cdef np.ndarray[np.int32_t, ndim=1] _prev_idx = np.full(max_prev_label + 1, -1, dtype=np.int32) + cdef int[:] curr_idx = _curr_idx + cdef int[:] prev_idx = _prev_idx + + for i in range(n_curr): + curr_idx[curr_IDs[i]] = i + for i in range(n_prev): + prev_idx[prev_IDs[i]] = i + + cdef np.ndarray[np.uint32_t, ndim=2] _intersections = np.zeros((n_curr, n_prev), dtype=np.uint32) + cdef unsigned int[:, :] intersections = _intersections + + # Single pass: count overlapping pixels between every (curr, prev) pair + for i in range(n_rows): + for j in range(n_cols): + c = lab[i, j] + p = prev_lab[i, j] + if c == 0 or p == 0: + continue + if c > max_curr_label or p > max_prev_label: + continue + ci_val = curr_idx[c] + pi_val = prev_idx[p] + if ci_val < 0 or pi_val < 0: + continue + intersections[ci_val, pi_val] += 1 + + cdef np.ndarray[np.float64_t, ndim=2] IoA_matrix = np.zeros((n_curr, n_prev), dtype=np.float64) + cdef double denom_val, I_val + + for ci in range(n_curr): + for pi in range(n_prev): + I_val = intersections[ci, pi] + if I_val == 0.0: + continue + if use_union: + denom_val = (curr_areas[ci] + prev_areas[pi]) - I_val + else: + denom_val = prev_areas[pi] + if denom_val == 0.0: + continue + IoA_matrix[ci, pi] = I_val / denom_val + + return IoA_matrix + +def calc_IoA_matrix_3D( + np.uint32_t[:, :, :] lab, + np.uint32_t[:, :, :] prev_lab, + np.uint32_t[:] curr_IDs, + np.uint32_t[:] prev_IDs, + np.uint32_t[:] prev_areas, + np.uint32_t[:] curr_areas, + bint use_union, +): + """Single-pass IoA matrix between two 3-D label images. See calc_IoA_matrix_2D.""" + cdef Py_ssize_t n_z = lab.shape[0] + cdef Py_ssize_t n_rows = lab.shape[1] + cdef Py_ssize_t n_cols = lab.shape[2] + cdef Py_ssize_t n_curr = curr_IDs.shape[0] + cdef Py_ssize_t n_prev = prev_IDs.shape[0] + cdef Py_ssize_t i, j, k, ci, pi + cdef unsigned int c, p, max_curr_label = 0, max_prev_label = 0 + cdef int ci_val, pi_val + + for i in range(n_curr): + if curr_IDs[i] > max_curr_label: + max_curr_label = curr_IDs[i] + for i in range(n_prev): + if prev_IDs[i] > max_prev_label: + max_prev_label = prev_IDs[i] + + cdef np.ndarray[np.int32_t, ndim=1] _curr_idx = np.full(max_curr_label + 1, -1, dtype=np.int32) + cdef np.ndarray[np.int32_t, ndim=1] _prev_idx = np.full(max_prev_label + 1, -1, dtype=np.int32) + cdef int[:] curr_idx = _curr_idx + cdef int[:] prev_idx = _prev_idx + + for i in range(n_curr): + curr_idx[curr_IDs[i]] = i + for i in range(n_prev): + prev_idx[prev_IDs[i]] = i + + cdef np.ndarray[np.uint32_t, ndim=2] _intersections = np.zeros((n_curr, n_prev), dtype=np.uint32) + cdef unsigned int[:, :] intersections = _intersections + + for i in range(n_z): + for j in range(n_rows): + for k in range(n_cols): + c = lab[i, j, k] + p = prev_lab[i, j, k] + if c == 0 or p == 0: + continue + if c > max_curr_label or p > max_prev_label: + continue + ci_val = curr_idx[c] + pi_val = prev_idx[p] + if ci_val < 0 or pi_val < 0: + continue + intersections[ci_val, pi_val] += 1 + + cdef np.ndarray[np.float64_t, ndim=2] IoA_matrix = np.zeros((n_curr, n_prev), dtype=np.float64) + cdef double denom_val, I_val + + for ci in range(n_curr): + for pi in range(n_prev): + I_val = intersections[ci, pi] + if I_val == 0.0: + continue + if use_union: + denom_val = (curr_areas[ci] + prev_areas[pi]) - I_val + else: + denom_val = prev_areas[pi] + if denom_val == 0.0: + continue + IoA_matrix[ci, pi] = I_val / denom_val + + return IoA_matrix \ No newline at end of file diff --git a/cellacdc/regionprops.py b/cellacdc/regionprops.py new file mode 100644 index 000000000..4f449db94 --- /dev/null +++ b/cellacdc/regionprops.py @@ -0,0 +1,1203 @@ +import numpy as np +from scipy import ndimage as ndi +import skimage.measure +import cv2 +from . import printl, debugutils +from skimage.measure._regionprops_utils import ( + _normalize_spacing, +) +import traceback as traceback + +try: + from cellacdc.precompiled.precompiled_functions import ( + find_all_objects_2D, + find_all_objects_3D, + object_projections_and_size_3D, + object_projection_and_size_3D, + ) + _CYTHON_FIND_OBJECTS = True + _CYTHON_OBJECT_PROJECTIONS = True + print('regionprops: imported precompiled find-objects helpers.') +except Exception: + _CYTHON_FIND_OBJECTS = False + _CYTHON_OBJECT_PROJECTIONS = False + print('[WARNING]: regionprops could not import precompiled find-objects helpers, falling back to scipy.ndimage.find_objects.') + +try: + from cellacdc.precompiled.precompiled_functions import most_common_projection_3D + _CYTHON_MOST_COMMON_PROJECTION = True + print('regionprops: imported precompiled most-common projection helper.') +except Exception: + _CYTHON_MOST_COMMON_PROJECTION = False + print('[WARNING]: regionprops could not import precompiled most-common projection helper, falling back to NumPy implementation.') + +# WARNING: Developers have already used +# 14 hrs +# to optimize this. +# In addition, implementing these optimizations in the codebase took +# 9 hrs +# Specifically the +# centroid (huge gain for 3D data) +# contour caching +# better find objects implementation to avoid iterating over None lists +# bbox caching +# targeted updates to RP +# stuff was targeted. +# If you decide to try and optimize it further, please update this warning :) + +_RegionProperties = skimage.measure._regionprops.RegionProperties +_cached = skimage.measure._regionprops._cached + +def _most_common_projection_ignore_zero_numpy(lab, axis): + """Most-common projection that ignores label 0 unless all values are 0.""" + moved = np.moveaxis(lab, axis, 0) + depth = moved.shape[0] + flat = moved.reshape(depth, -1) + out = np.zeros(flat.shape[1], dtype=lab.dtype) + + for col in range(flat.shape[1]): + line = flat[:, col] + nonzero = line[line != 0] + if nonzero.size == 0: + continue + + labels, counts = np.unique(nonzero, return_counts=True) + # np.unique sorts ascending, so ties are resolved by smallest label. + out[col] = labels[np.argmax(counts)] + + return out.reshape(moved.shape[1:]) + +def _object_projection_and_size_numpy(cutout, obj_id, axis): + mask = cutout == obj_id + proj = np.any(mask, axis=axis).astype(np.uint8) + size = int(np.count_nonzero(mask)) + return proj, size + +def _acdc_regionprops_factory( + label_image, + intensity_image=None, + cache=True, + *, + extra_properties=None, + spacing=None, + offset=None, + ): + if label_image.ndim not in (2, 3): + raise TypeError('Only 2-D and 3-D images supported.') + + if not np.issubdtype(label_image.dtype, np.integer): + if np.issubdtype(label_image.dtype, bool): + raise TypeError( + 'Non-integer image types are ambiguous: ' + 'use skimage.measure.label to label the connected ' + 'components of label_image, ' + 'or label_image.astype(np.uint8) to interpret ' + 'the True values as a single label.' + ) + raise TypeError('Non-integer label_image types are ambiguous') + + if offset is None: + offset_arr = np.zeros((label_image.ndim,), dtype=int) + else: + offset_arr = np.asarray(offset) + if offset_arr.ndim != 1 or offset_arr.size != label_image.ndim: + raise ValueError( + 'Offset should be an array-like of integers ' + 'of shape (label_image.ndim,); ' + f'{offset} was provided.' + ) + + regions = [] + if _CYTHON_FIND_OBJECTS: + img_uint32 = label_image.astype(np.uint32, copy=False) + if label_image.ndim == 2: + out = find_all_objects_2D(img_uint32) + labels, bboxes = out + for i in range(len(labels)): + sl = (slice(int(bboxes[i, 0]), int(bboxes[i, 1])), + slice(int(bboxes[i, 2]), int(bboxes[i, 3]))) + regions.append(acdcRegionProperties( + sl, int(labels[i]), label_image, intensity_image, cache, + spacing=spacing, extra_properties=extra_properties, + offset=offset_arr, + )) + else: + out = find_all_objects_3D(img_uint32) + labels, bboxes = out + for i in range(len(labels)): + sl = (slice(int(bboxes[i, 0]), int(bboxes[i, 1])), + slice(int(bboxes[i, 2]), int(bboxes[i, 3])), + slice(int(bboxes[i, 4]), int(bboxes[i, 5]))) + regions.append(acdcRegionProperties( + sl, int(labels[i]), label_image, intensity_image, cache, + spacing=spacing, extra_properties=extra_properties, + offset=offset_arr, + )) + else: + objects = ndi.find_objects(label_image) + for i, sl in enumerate(objects, start=1): + if sl is None: + continue + regions.append(acdcRegionProperties( + sl, i, label_image, intensity_image, cache, + spacing=spacing, extra_properties=extra_properties, + offset=offset_arr, + )) + return regions + +class acdcRegionProperties(_RegionProperties): + def __init__( + self, + slice, + label, + label_image, + intensity_image, + cache_active, + *, + extra_properties=None, + spacing=None, + offset=None, + ): + super().__init__( + slice, label, label_image, intensity_image, cache_active, + extra_properties=extra_properties, spacing=spacing, offset=offset + ) + # @property + # @_cached + # def slice(self): + # # scale slice with offset + # return tuple( + # slice(self._slice[i].start + self._offset[i], + # self._slice[i].stop + self._offset[i]) + # for i in range(self._ndim) + # ) + + @property + def image(self): + """Return cached object mask from the current label image.""" + imgage = self._cache.get('image') + if imgage is None or not np.any(imgage): + self._cache['image'] = self._label_image[self._slice] == self.label + + return self._cache['image'] + + @property + @_cached + def bbox(self): + """ + Returns + ------- + A tuple of the bounding box's start coordinates for each dimension, + followed by the end coordinates for each dimension. + """ + return tuple( + [self.slice[i].start for i in range(self._ndim)] + + [self.slice[i].stop for i in range(self._ndim)] + ) + + @property + @_cached # slow for 3D data, better cache it + def centroid(self): + return super().centroid + + @property + @_cached + def contour(self): + contours = self._contours_local(retrieve_mode=cv2.RETR_EXTERNAL) + if not contours: + return np.empty((0, 2), dtype=np.int32) + + contour = max(contours, key=len) + contour = np.squeeze(contour, axis=1) + contour = np.vstack((contour, contour[0])) + return contour + self._xy_offset + + @property + @_cached + def contour_all(self): + # Include both outer boundaries and holes. + contours = self._contours_local(retrieve_mode=cv2.RETR_CCOMP) + if not contours: + return [] + offset = self._xy_offset + return [np.squeeze(cont, axis=1) + offset for cont in contours] + + @property + @_cached + def _xy_offset(self): + if self._ndim != 2: + raise AttributeError('contour is only supported for 2D objects.') + slc = self.slice + return np.array([slc[1].start, slc[0].start], dtype=np.int32) + + def _contours_local(self, retrieve_mode=cv2.RETR_EXTERNAL): + if self._ndim != 2: + raise AttributeError('contour is only supported for 2D objects.') + obj_image = np.ascontiguousarray(self.image, dtype=np.uint8) + contours, _ = cv2.findContours( + obj_image, retrieve_mode, cv2.CHAIN_APPROX_NONE + ) + return contours + +class acdcRegionprops: + def __init__( + self, + lab, + acdc_df=None, + centroids_loaded=None, + IDs_loaded=None, + centroids_IDs_exact_loaded=None, + ID_to_idx_loaded=None, + precache_centroids=True, + **kwargs, + ): + self.lab = lab + self.acdc_df = acdc_df + self._rp = _acdc_regionprops_factory(lab, **kwargs) + self.is3D = self.lab.ndim == 3 + self._slice_rps = { + 'z': {}, + 'y': {}, + 'x': {}, + } + self._proj_rps = { + 'z': {}, + 'y': {}, + 'x': {}, + } + self._proj_labs = { + 'z': {}, + 'y': {}, + 'x': {}, + } + self._centroid_mapper = {} + self._centroid_IDs_exact = set() + if IDs_loaded is None or ID_to_idx_loaded is None: + self.set_attributes(update_centroid_mapper=False) + else: + self.ID_to_idx = ID_to_idx_loaded + self.IDs_set = set(IDs_loaded) + self.IDs = list(self.IDs_set) + + if centroids_IDs_exact_loaded is not None and centroids_loaded is not None: + self._centroid_mapper = centroids_loaded + self._centroid_IDs_exact = set(centroids_IDs_exact_loaded) + elif precache_centroids: + self.precache_centroids() + + else: + self._centroid_mapper = dict() + + def get_slice_rp(self, slice_number, slicing='z'): + if not self.is3D: + raise ValueError('Slice-specific regionprops are only supported for 3D labels.') + + slicing = self._normalize_slicing(slicing) + slice_number = int(slice_number) + self._validate_slice_number(slice_number, slicing) + + rp = self._slice_rps[slicing].get(slice_number) + if rp is None: + lab_slice = self._get_lab_slice(self.lab, slice_number, slicing) + rp = acdcRegionprops(lab_slice, precache_centroids=False) + self._slice_rps[slicing][slice_number] = rp + return rp + + def get_proj_rp(self, kind='max', slicing='z'): + if not self.is3D: + raise ValueError('Projection-specific regionprops are only supported for 3D labels.') + + slicing = self._normalize_slicing(slicing) + kind = self._normalize_projection_kind(kind) + + rp = self._proj_rps[slicing].get(kind) + if rp is None: + lab_proj = self._proj_labs[slicing].get(kind) + if lab_proj is None: + lab_proj = self._get_lab_projection(self.lab, slicing=slicing, kind=kind) + self._proj_labs[slicing][kind] = lab_proj + rp = acdcRegionprops(lab_proj, precache_centroids=False) + self._proj_rps[slicing][kind] = rp + return rp + + def get_obj_from_slice_rp(self, ID, slice_number, slicing='z', warn=True): + rp = self.get_slice_rp(slice_number, slicing=slicing) + return rp.get_obj_from_ID(ID, warn=warn) + + def get_obj_from_proj_rp(self, ID, kind='max', slicing='z', warn=True): + kind = self._normalize_projection_kind(kind) + rp = self.get_proj_rp(kind=kind, slicing=slicing) + return rp.get_obj_from_ID(ID, warn=warn) + + def __iter__(self): + return iter(self._rp) + + def __len__(self): + return len(self._rp) + + def __getitem__(self, idx): + return self._rp[idx] + + def __setitem__(self, idx, value): + self._rp[idx] = value + + def __repr__(self): + return repr(self._rp) + + def _normalize_slicing(self, slicing): + slicing = str(slicing).lower() + if slicing not in ('z', 'y', 'x'): + raise ValueError( + f'Invalid slicing "{slicing}". Valid options are "z", "y", and "x".' + ) + return slicing + + def _slice_axis_index(self, slicing): + axis_map = {'z': 0, 'y': 1, 'x': 2} + return axis_map[slicing] + + def _normalize_projection_kind(self, kind): + kind = str(kind).lower().strip() + kind_norm = kind.replace('-', ' ').replace('_', ' ') + + if kind_norm.startswith('max'): + return 'max' + if kind_norm.startswith('mean'): + return 'mean' + if kind_norm.startswith('median'): + return 'median' + if kind_norm.startswith('most common') or kind_norm.startswith('mode'): + return 'most_common' + + if kind not in ('max', 'mean', 'median', 'most_common'): + raise ValueError( + f'Invalid projection kind "{kind}". ' + 'Valid options are "max", "mean", "median", and "most_common".' + ) + return kind + + def _validate_slice_number(self, slice_number, slicing): + axis = self._slice_axis_index(slicing) + axis_size = self.lab.shape[axis] + if slice_number < 0 or slice_number >= axis_size: + raise IndexError( + f'Slice number {slice_number} is out of bounds for slicing "{slicing}" ' + f'with size {axis_size}.' + ) + + def _has_initialized_slice_rps(self): + return any(len(slice_dict) > 0 for slice_dict in self._slice_rps.values()) + + def _has_initialized_proj_rps(self): + return any(len(proj_dict) > 0 for proj_dict in self._proj_rps.values()) + + def _iter_initialized_slice_rps(self): + for slicing, slice_dict in self._slice_rps.items(): + for slice_number, rp in slice_dict.items(): + yield slicing, slice_number, rp + + def _iter_initialized_proj_rps(self): + for slicing, proj_dict in self._proj_rps.items(): + for kind, rp in proj_dict.items(): + yield slicing, kind, rp + + def _get_lab_slice(self, lab, slice_number, slicing): + if lab.ndim != 3: + raise ValueError( + f'Slice-specific regionprops are only supported for 3D labels, got {lab.ndim}D.' + ) + + slicing = self._normalize_slicing(slicing) + if slicing == 'z': + return lab[slice_number, :, :] + if slicing == 'y': + return lab[:, slice_number, :] + return lab[:, :, slice_number] + + def _get_lab_projection(self, lab, slicing='z', kind='max'): + if lab.ndim != 3: + raise ValueError( + f'Projection-specific regionprops are only supported for 3D labels, got {lab.ndim}D.' + ) + + axis = self._slice_axis_index(self._normalize_slicing(slicing)) + kind = self._normalize_projection_kind(kind) + if kind == 'max': + return np.max(lab, axis=axis) + + if kind == 'most_common': + return self._compute_most_common_projection(lab, axis=axis) + + if kind == 'mean': + projected = np.mean(lab, axis=axis) + else: + projected = np.median(lab, axis=axis) + + # Regionprops requires integer labels. + return np.rint(projected).astype(lab.dtype, copy=False) + + def _compute_most_common_projection(self, lab, axis): + if _CYTHON_MOST_COMMON_PROJECTION: + lab_uint32 = lab.astype(np.uint32, copy=False) + projected = most_common_projection_3D(lab_uint32, int(axis)) + return projected.astype(lab.dtype, copy=False) + + projected = _most_common_projection_ignore_zero_numpy(lab, axis) + return projected.astype(lab.dtype, copy=False) + + def _projection_canvas_shape(self, slicing): + slicing = self._normalize_slicing(slicing) + if slicing == 'z': + return (self.lab.shape[1], self.lab.shape[2]) + if slicing == 'y': + return (self.lab.shape[0], self.lab.shape[2]) + return (self.lab.shape[0], self.lab.shape[1]) + + def _project_cutout_for_slicing_and_size(self, cutout, obj_id, slicing): + axis = self._slice_axis_index(slicing) + if _CYTHON_OBJECT_PROJECTIONS: + cutout_uint32 = cutout.astype(np.uint32, copy=False) + proj, size = object_projection_and_size_3D( + cutout_uint32, int(obj_id), int(axis) + ) + return proj.astype(bool, copy=False), int(size) + + proj, size = _object_projection_and_size_numpy(cutout, int(obj_id), int(axis)) + return proj.astype(bool, copy=False), int(size) + + def _project_object_from_bbox(self, lab, obj_id, bbox, slicing): + z0, y0, x0, z1, y1, x1 = [int(v) for v in bbox] + if z0 >= z1 or y0 >= y1 or x0 >= x1: + return None + + cutout = lab[z0:z1, y0:y1, x0:x1] + proj_mask, size = self._project_cutout_for_slicing_and_size( + cutout, obj_id, slicing + ) + if size == 0: + return None + + if slicing == 'z': + out_slice = (slice(y0, y1), slice(x0, x1)) + elif slicing == 'y': + out_slice = (slice(z0, z1), slice(x0, x1)) + else: + out_slice = (slice(z0, z1), slice(y0, y1)) + + return out_slice, proj_mask, int(size) + + def _iter_projected_objects_sorted(self, slicing='z', lab=None): + if not self.is3D: + raise ValueError('Projection helpers are only supported for 3D labels.') + + slicing = self._normalize_slicing(slicing) + if lab is None: + lab = self.lab + + projected = [] + for obj in self._rp: + if len(obj.bbox) != 6: + continue + out = self._project_object_from_bbox(lab, obj.label, obj.bbox, slicing) + if out is None: + continue + out_slice, proj_mask, size = out + projected.append((int(obj.label), int(size), out_slice, proj_mask)) + + # Draw large objects first, then smaller objects so small ones stay visible on top. + projected.sort(key=lambda entry: (-entry[1], entry[0])) + return projected + + def get_projection_lab_sorted(self, slicing='z', dtype=None): + if not self.is3D: + raise ValueError('Projection helpers are only supported for 3D labels.') + + slicing = self._normalize_slicing(slicing) + out_shape = self._projection_canvas_shape(slicing) + if dtype is None: + dtype = self.lab.dtype + + lab_proj = np.zeros(out_shape, dtype=dtype) + for label, _, out_slice, proj_mask in self._iter_projected_objects_sorted(slicing=slicing): + lab_proj[out_slice][proj_mask] = label + return lab_proj + + def _get_projection_patch_slices(self, slicing, cutout_bbox): + z0, y0, x0, z1, y1, x1 = [int(v) for v in cutout_bbox] + if slicing == 'z': + return (slice(y0, y1), slice(x0, x1)) + if slicing == 'y': + return (slice(z0, z1), slice(x0, x1)) + return (slice(z0, z1), slice(y0, y1)) + + def _get_projection_patch_lab(self, slicing, cutout_bbox): + z0, y0, x0, z1, y1, x1 = [int(v) for v in cutout_bbox] + if slicing == 'z': + return self.lab[:, y0:y1, x0:x1] + if slicing == 'y': + return self.lab[z0:z1, :, x0:x1] + return self.lab[z0:z1, y0:y1, :] + + def _compute_most_common_projection_patch(self, slicing, cutout_bbox): + patch_lab = self._get_projection_patch_lab(slicing, cutout_bbox) + axis = self._slice_axis_index(slicing) + return self._compute_most_common_projection(patch_lab, axis=axis) + + def _update_cached_most_common_projection_locally(self, slicing, cutout_bbox): + lab_proj = self._get_cached_or_new_lab_projection(slicing, 'most_common') + patch_slices = self._get_projection_patch_slices(slicing, cutout_bbox) + if any(slc.start >= slc.stop for slc in patch_slices): + return lab_proj + + patch = self._compute_most_common_projection_patch(slicing, cutout_bbox) + lab_proj[patch_slices] = patch + self._proj_labs[slicing]['most_common'] = lab_proj + return lab_proj + + def _get_cached_or_new_lab_projection(self, slicing, kind): + lab_proj = self._proj_labs[slicing].get(kind) + if lab_proj is None: + lab_proj = self._get_lab_projection(self.lab, slicing=slicing, kind=kind) + self._proj_labs[slicing][kind] = lab_proj + return lab_proj + + def _replace_cached_lab_projection(self, slicing, kind): + lab_proj = self._get_lab_projection(self.lab, slicing=slicing, kind=kind) + self._proj_labs[slicing][kind] = lab_proj + return lab_proj + + def _apply_assignments_to_lab_from_rp(self, rp, lab, assignments): + active_assignments = { + int(old_ID): int(new_ID) + for old_ID, new_ID in assignments.items() + if old_ID != new_ID + } + if not active_assignments: + return lab + + dst = lab.copy() + for obj in rp: + new_ID = active_assignments.get(obj.label) + if new_ID is None: + continue + dst[obj.slice][obj.image] = new_ID + return dst + + def _apply_deletions_to_lab_from_rp(self, rp, lab, IDs_to_delete): + IDs_to_delete = set(IDs_to_delete) + if not IDs_to_delete: + return lab + + dst = lab.copy() + for obj in rp: + if obj.label in IDs_to_delete: + dst[obj.slice][obj.image] = 0 + return dst + + def _sync_initialized_slice_rps_via_assignments(self, assignments): + if not self._has_initialized_slice_rps(): + return + + for slicing, slice_number, rp in self._iter_initialized_slice_rps(): + lab_slice = self._get_lab_slice(self.lab, slice_number, slicing) + rp.update_regionprops_via_assignments(assignments, lab_slice) + + def _sync_initialized_proj_rps_via_assignments(self, assignments): + if not self._has_initialized_proj_rps(): + return + + for slicing, kind, rp in self._iter_initialized_proj_rps(): + lab_proj = self._get_cached_or_new_lab_projection(slicing, kind) + lab_proj = self._apply_assignments_to_lab_from_rp( + rp, lab_proj, assignments + ) + self._proj_labs[slicing][kind] = lab_proj + rp.update_regionprops_via_assignments(assignments, lab_proj) + + def _sync_initialized_slice_rps_via_deletions(self, IDs_to_delete): + if not self._has_initialized_slice_rps(): + return + + for _, _, rp in self._iter_initialized_slice_rps(): + rp.update_regionprops_via_deletions(IDs_to_delete) + + def _sync_initialized_proj_rps_via_deletions(self, IDs_to_delete): + if not self._has_initialized_proj_rps(): + return + + for slicing, kind, rp in self._iter_initialized_proj_rps(): + lab_proj = self._get_cached_or_new_lab_projection(slicing, kind) + lab_proj = self._apply_deletions_to_lab_from_rp( + rp, lab_proj, IDs_to_delete + ) + self._proj_labs[slicing][kind] = lab_proj + rp.update_regionprops_via_deletions(IDs_to_delete) + + def _sync_initialized_slice_rps_via_update(self, specific_IDs_update_centroids=None): + if not self._has_initialized_slice_rps(): + return + + for slicing, slice_number, rp in self._iter_initialized_slice_rps(): + lab_slice = self._get_lab_slice(self.lab, slice_number, slicing) + rp.update_regionprops( + lab_slice, + specific_IDs_update_centroids=specific_IDs_update_centroids, + ) + + def _sync_initialized_proj_rps_via_update( + self, + specific_IDs_update_centroids=None, + cutout_bbox=None, + ): + if not self._has_initialized_proj_rps(): + return + + for slicing, kind, rp in self._iter_initialized_proj_rps(): + if cutout_bbox is not None and kind == 'most_common': + lab_proj = self._update_cached_most_common_projection_locally( + slicing, cutout_bbox + ) + else: + lab_proj = self._replace_cached_lab_projection(slicing, kind) + rp.update_regionprops( + lab_proj, + specific_IDs_update_centroids=specific_IDs_update_centroids, + ) + + def _normalize_cutout_bbox(self, cutout_bbox): + """Normalize cutout_bbox to always be 3D (6 values) for 3D data. + + Automatically expands 2D bbox (4 values: y_start, x_start, y_end, x_end) + to 3D bbox (6 values: z_start, y_start, x_start, z_end, y_end, x_end) + covering all z-slices. + """ + if self.is3D: + if len(cutout_bbox) == 4: + # 2D bbox: expand to 3D with full z range + y_start, x_start, y_end, x_end = cutout_bbox + return (0, y_start, x_start, self.lab.shape[0], y_end, x_end) + elif len(cutout_bbox) != 6: + raise ValueError( + 'For 3D labels, cutout_bbox should have 4 values (2D) or 6 values (3D): ' + f'got {len(cutout_bbox)}.' + ) + else: + if len(cutout_bbox) != 4: + raise ValueError( + 'For 2D labels, cutout_bbox should have 4 values (y_start, x_start, y_end, x_end), ' + f'got {len(cutout_bbox)}.' + ) + return cutout_bbox + + def _get_centroid_df_from_df(self): + if self.acdc_df is None or len(self.acdc_df) == 0: + return {} + + centroid_cols = ['y_centroid', 'x_centroid'] + if self.is3D and 'z_centroid' in self.acdc_df.columns: + centroid_cols = ['z_centroid', 'y_centroid', 'x_centroid'] + + if not set(centroid_cols).issubset(self.acdc_df.columns): + return {} + + if 'Cell_ID' in self.acdc_df.columns: + centroid_df = self.acdc_df.set_index('Cell_ID')[centroid_cols] + elif 'ID' in self.acdc_df.columns: + centroid_df = self.acdc_df.set_index('ID')[centroid_cols] + else: + centroid_df = self.acdc_df[centroid_cols] + + return { + int(ID): tuple(values) + for ID, values in centroid_df.iterrows() + } + + def _get_bbox_centers_mapper( + self, objs=None, IDs_to_include=None, IDs_to_exclude=None + ): + if objs is None and not self._rp: + return {} + + if objs is None: + if IDs_to_include is None: + IDs_to_include = ( + self.IDs_set.difference(IDs_to_exclude) + if IDs_to_exclude is not None else self.IDs_set + ) + ids = set(IDs_to_include) + objs = [obj for obj in self._rp if obj.label in ids] + + if not objs: + return {} + + ndim = 2 if not self.is3D else 3 + labels = np.empty(len(objs), dtype=int) + bboxes = np.empty((len(objs), ndim * 2), dtype=float) + for i, obj in enumerate(objs): + labels[i] = obj.label + bboxes[i] = obj.bbox + + centers = (bboxes[:, :ndim] + bboxes[:, ndim:]) / 2.0 + return { + int(label): tuple(center) + for label, center in zip(labels, centers) + } + + def precache_centroids(self): + centroid_df = self._get_centroid_df_from_df() + IDs_from_df = set(centroid_df) + IDs_missing_centroid = self.IDs_set.difference(IDs_from_df) + bbox_centers_mapper = self._get_bbox_centers_mapper( + IDs_to_include=IDs_missing_centroid + ) + self._centroid_mapper = {**bbox_centers_mapper, **centroid_df} + self._centroid_IDs_exact = IDs_from_df + + def set_attributes(self, deleted_IDs=None, update_centroid_mapper=True): + self.ID_to_idx = {obj.label: idx for idx, obj in enumerate(self._rp)} + # Update IDs and IDs_set separately and explicitly + self.IDs_set = set(self.ID_to_idx) + self.IDs = list(self.IDs_set) + + if not update_centroid_mapper: + return + if deleted_IDs is not None: + for ID in deleted_IDs: + self._centroid_mapper.pop(ID, None) + self._centroid_IDs_exact.discard(ID) + else: + self._centroid_mapper = { + ID: centroid + for ID, centroid in self._centroid_mapper.items() + if ID in self.IDs_set + } + self._centroid_IDs_exact.intersection_update(self.IDs_set) + + def get_obj_from_ID(self, ID, warn=True): + idx = self.ID_to_idx.get(ID, None) + if idx is not None: + return self._rp[idx] + else: + if warn: + # get caller info + debugutils.print_call_stack() + print(f"Warning: Object with ID {ID} not found in regionprops.") + return None + + def delete_IDs(self, IDs_to_delete: set[int], update_other_attrs=True): + if not IDs_to_delete: + return + + self._rp = [ + obj for obj in self._rp if obj.label not in IDs_to_delete + ] + + if not update_other_attrs: + return + self.set_attributes(deleted_IDs=IDs_to_delete) + + def _get_IDs_to_update_centroids( + self, lab, objs, specific_IDs_update_centroids=None + ): + if specific_IDs_update_centroids is not None: + return set(specific_IDs_update_centroids) + + obj_to_update = set() + for obj in objs: + has_to_update = False + ID = obj.label + old_centroid = self._centroid_mapper.get(ID, None) + if old_centroid is not None: + rounded_centroid = tuple(np.round(old_centroid).astype(int)) + try: + ID_lab = lab[rounded_centroid] + except Exception: + has_to_update = True + else: + if ID_lab != ID: + has_to_update = True + else: + has_to_update = True + + if has_to_update: + obj_to_update.add(ID) + + return obj_to_update + + def update_regionprops( + self, lab, specific_IDs_update_centroids=None, + update_centroids=True + ): + old_rp_by_id = {obj.label: obj for obj in self._rp} + + new_rp = _acdc_regionprops_factory(lab) + + if update_centroids: + # Verify that the cached centroid is still inside the object mask. + obj_to_update = self._get_IDs_to_update_centroids( + lab, new_rp, + specific_IDs_update_centroids=specific_IDs_update_centroids + ) + + bbox_centers_mapper = self._get_bbox_centers_mapper( + objs=[obj for obj in new_rp if obj.label in obj_to_update] + ) + + # update centroids + self._centroid_mapper.update(bbox_centers_mapper) + + # remove from exact set if we updated the centroid + self._centroid_IDs_exact.difference_update(obj_to_update) + + for obj in new_rp: + self._copy_custom_rp_attributes(obj, old_rp_by_id.get(obj.label)) + + self._rp = new_rp + self.lab = lab + self.set_attributes() + self._sync_initialized_slice_rps_via_update( + specific_IDs_update_centroids=specific_IDs_update_centroids + ) + self._sync_initialized_proj_rps_via_update( + specific_IDs_update_centroids=specific_IDs_update_centroids + ) + + def _copy_custom_rp_attributes(self, new_obj, old_obj): + if old_obj is None: + return + new_obj.dead = getattr(old_obj, 'dead', False) + new_obj.excluded = getattr(old_obj, 'excluded', False) + + def _get_bbox_slices(self, bbox, depth_axis=None): + ndim = self.lab.ndim + if len(bbox) != ndim * 2: + raise ValueError( + f'Expected a bounding box with {ndim*2} values, ' + f'got {len(bbox)}.' + ) + + return tuple( + slice(int(bbox[dim]), int(bbox[dim+ndim])) for dim in range(ndim) + ) + + def _translate_cutout_regionprop(self, obj, offset, lab): + offset_arr = np.asarray(offset) + centroid = obj.centroid + translated_slice = tuple( + slice( + obj._slice[dim].start + offset_arr[dim], + obj._slice[dim].stop + offset_arr[dim], + ) + for dim in range(obj._ndim) + ) + translated_bbox = tuple( + [slc.start for slc in translated_slice] + + [slc.stop for slc in translated_slice] + ) + translated_centroid = tuple( + coord + offset_arr[dim] + for dim, coord in enumerate(centroid) + ) + + obj._label_image = lab + obj._slice = translated_slice + obj.slice = translated_slice + obj._offset = np.zeros_like(offset_arr) + obj._cache['slice'] = translated_slice + obj._cache['bbox'] = translated_bbox + obj._cache['centroid'] = translated_centroid + return obj + + def _get_separate_obj_regionprops(self, lab, IDs): + IDs = tuple(int(ID) for ID in IDs) + if not IDs: + return {} + + mask = np.isin(lab, IDs) + if not np.any(mask): + return {} + + isolated_lab = np.zeros_like(lab) + isolated_lab[mask] = lab[mask] + return { + obj.label: obj + for obj in _acdc_regionprops_factory(isolated_lab) + if obj.label in IDs + } + + def _is_bbox_touching_cutout_border(self, bbox, shape): + ndim = len(shape) + for dim in range(ndim): + if bbox[dim] == 0 or bbox[dim+ndim] == shape[dim]: + return True + return False + + def _obj_intersects_bbox(self, obj, bbox): + ndim = self.lab.ndim + obj_bbox = obj.bbox + for dim in range(ndim): + start = max(int(obj_bbox[dim]), int(bbox[dim])) + stop = min(int(obj_bbox[dim+ndim]), int(bbox[dim+ndim])) + if start >= stop: + return False + + return True + + def _get_old_cutout_IDs_from_rp(self, cutout_bbox): + return { + obj.label for obj in self._rp + if self._obj_intersects_bbox(obj, cutout_bbox) + } + + def _set_label_image(self, lab, objs=None, clear_cache=False): + if lab is None: + return + + self.lab = lab + if objs is None: + objs = self._rp + + for obj in objs: + obj._label_image = lab + if clear_cache: + obj._cache.clear() + + def update_regionprops_via_assignments( + self, assignments:dict[int, int], lab + ): + """If the lab is completely the same, but only ID changes/swaps have been made + + Parameters + ---------- + assignments : dict[int, int] + key: old ID, + value: new ID + lab : np.ndarray, optional + Updated label image. When provided, regionprops objects are rebound + to this image so properties such as ``image`` stay consistent after + the ID remap. + """ + active_assignments = { + int(old_ID): int(new_ID) + for old_ID, new_ID in assignments.items() + if old_ID in self.IDs_set and old_ID != new_ID + } + if not active_assignments: + self._set_label_image(lab) + self._sync_initialized_slice_rps_via_assignments({}) + self._sync_initialized_proj_rps_via_assignments({}) + return + + # if not active_assignments: + # if lab is not None: + # self._set_label_image(lab) + # return + + # remapped_IDs = set() + # for obj in self._rp: + # old_ID = obj.label + # new_ID = active_assignments.get(old_ID, old_ID) + # if new_ID in remapped_IDs: + # raise ValueError( + # 'Assignments would create duplicate IDs in regionprops. ' + # 'Use a full regionprops recomputation for merges.' + # ) + # remapped_IDs.add(new_ID) + + centroid_mapper = { + active_assignments.get(ID, ID): centroid + for ID, centroid in self._centroid_mapper.items() + # if active_assignments.get(ID, ID) in remapped_IDs + } + centroid_IDs_exact = { + active_assignments.get(ID, ID) + for ID in self._centroid_IDs_exact + # if active_assignments.get(ID, ID) in remapped_IDs + } + + # Rebind first so any property access during remap sees the current lab. + self._set_label_image(lab) + + for obj in self._rp: + old_ID = obj.label + new_ID = active_assignments.get(old_ID, old_ID) + obj.label = new_ID + # if obj.area == 0: + # # if area is 0, centroid is not defined and we should not trust the cached one + # print("area 0...") + + self._centroid_mapper = centroid_mapper + self._centroid_IDs_exact = centroid_IDs_exact + self.set_attributes(update_centroid_mapper=False) # update the mapper + self._sync_initialized_slice_rps_via_assignments(active_assignments) + self._sync_initialized_proj_rps_via_assignments(active_assignments) + + def update_regionprops_via_deletions( + self, IDs_to_delete: set[int] + ): + """If the lab is completely the same, but only some IDs have been deleted + + Parameters + ---------- + IDs_to_delete : set[int] + IDs to delete + """ + IDs_to_delete = set(IDs_to_delete).intersection(self.IDs_set) + if not IDs_to_delete: + return + self._rp = [obj for obj in self._rp if obj.label not in IDs_to_delete] + self.set_attributes(deleted_IDs=IDs_to_delete) # for updating the IDs to indx, centroid mapper + self._sync_initialized_slice_rps_via_deletions(IDs_to_delete) + self._sync_initialized_proj_rps_via_deletions(IDs_to_delete) + + def update_regionprops_via_cutout( + self, lab, cutout_bbox, specific_IDs=None, depth_axis=None + ): + """Only relabels the regionprops of a specific cutout. + Is only faster for small cutouts. I dont have a number, but I would say + less than 30% of total image size. + + Parameters + ---------- + cutout_lab : np.ndarray + The labeled cutout image. + cutout_bbox : tuple[int, int, int, int] + The bounding box of the cutout in the format (min_row, min_col, max_row, max_col). + """ + if specific_IDs is not None and not isinstance(specific_IDs, (list, set, np.ndarray, tuple)): + specific_IDs = {specific_IDs} + elif specific_IDs is not None: + specific_IDs = set(specific_IDs) + + self.lab = lab + # Normalize bbox to 3D (expands 2D bbox to full z-range) + cutout_bbox = self._normalize_cutout_bbox(cutout_bbox) + cutout_slices = self._get_bbox_slices(cutout_bbox, depth_axis=depth_axis) + new_cutout = lab[cutout_slices] + old_cutout_IDs = self._get_old_cutout_IDs_from_rp(cutout_bbox) + rp_cutout_new = _acdc_regionprops_factory(new_cutout) + new_cutout_IDs = set(obj.label for obj in rp_cutout_new) + + if not old_cutout_IDs and not new_cutout_IDs: + return + + target_IDs = ( + old_cutout_IDs.union(new_cutout_IDs) + if specific_IDs is None + else old_cutout_IDs.union(new_cutout_IDs).intersection(specific_IDs) + ) + + deleted_target_IDs = old_cutout_IDs.difference(new_cutout_IDs).intersection( + target_IDs + ) + + refreshed_IDs = new_cutout_IDs.intersection(target_IDs) + + conflicting_IDs = refreshed_IDs.difference(old_cutout_IDs).intersection( + self.IDs_set.difference(old_cutout_IDs) + ) + if conflicting_IDs: + raise ValueError( + 'Cutout update would reuse IDs that already belong to objects ' + 'outside the cutout. Use a full regionprops recomputation.' + ) + + old_rp_by_id = {obj.label: obj for obj in self._rp} + IDs_to_replace = old_cutout_IDs.intersection(target_IDs) + unaffected_rp = [obj for obj in self._rp if obj.label not in IDs_to_replace] + + offset = tuple(s.start for s in cutout_slices) + + border_touching_IDs = { + obj.label + for obj in rp_cutout_new + if obj.label in refreshed_IDs + and self._is_bbox_touching_cutout_border(obj.bbox, new_cutout.shape) + } + separate_objs = self._get_separate_obj_regionprops(lab, border_touching_IDs) + + new_objs = [] + updated_centroid_IDs = set() + for obj in rp_cutout_new: + ID = obj.label + if ID not in refreshed_IDs: + continue + if ID in border_touching_IDs: + # edge case: ID changed is outside the cutout + new_obj = separate_objs.get(ID) + if new_obj is None: + continue + else: + new_obj = self._translate_cutout_regionprop(obj, offset, lab) + + self._copy_custom_rp_attributes(new_obj, old_rp_by_id.get(ID)) + new_objs.append(new_obj) + updated_centroid_IDs.add(ID) + + for ID in deleted_target_IDs: + self._centroid_mapper.pop(ID, None) + self._centroid_IDs_exact.discard(ID) + + if updated_centroid_IDs: + obj_to_update = self._get_IDs_to_update_centroids( + lab, new_objs, + specific_IDs_update_centroids=updated_centroid_IDs + ) + + self._centroid_mapper.update( + self._get_bbox_centers_mapper( + objs=[obj for obj in new_objs if obj.label in obj_to_update] + ) + ) + self._centroid_IDs_exact.difference_update(obj_to_update) + + self._rp = unaffected_rp + new_objs + self._set_label_image(lab) + self.set_attributes(update_centroid_mapper=False) + self._sync_initialized_slice_rps_via_update( + specific_IDs_update_centroids=target_IDs + ) + self._sync_initialized_proj_rps_via_update( + specific_IDs_update_centroids=target_IDs, + cutout_bbox=cutout_bbox, + ) + + def get_centroid(self, ID, exact=False): + if exact and ID not in self._centroid_IDs_exact: + obj = self.get_obj_from_ID(ID) + centroid = obj.centroid + try: + int(centroid[0]) + except (TypeError, ValueError): + print(f"Warning: Centroid for ID {ID} is not a valid coordinate: {centroid}. " + f"Object size: {obj.bbox}. Returning None.") + return None + self._centroid_mapper[ID] = centroid + self._centroid_IDs_exact.add(ID) + return centroid + + centroid = self._centroid_mapper.get(ID, None) + if centroid is None: + # add centroid to mapper if not found + objs = [self.get_obj_from_ID(ID)] + bbox_centers_mapper = self._get_bbox_centers_mapper(objs=objs) + self._centroid_mapper.update(bbox_centers_mapper) + centroid = self._centroid_mapper.get(ID, None) + return centroid + + def copy(self): + new_instance = acdcRegionprops( + self.lab, precache_centroids=False + ) + new_instance._rp = [obj for obj in self._rp] + new_instance._centroid_mapper = self._centroid_mapper.copy() + new_instance._centroid_IDs_exact = self._centroid_IDs_exact.copy() + for slicing, slice_number, rp in self._iter_initialized_slice_rps(): + new_instance._slice_rps[slicing][slice_number] = rp.copy() + for slicing, kind, rp in self._iter_initialized_proj_rps(): + new_instance._proj_rps[slicing][kind] = rp.copy() + for slicing, proj_labs in self._proj_labs.items(): + for kind, lab_proj in proj_labs.items(): + new_instance._proj_labs[slicing][kind] = lab_proj.copy() + new_instance.set_attributes(update_centroid_mapper=False) + return new_instance \ No newline at end of file diff --git a/cellacdc/segm.py b/cellacdc/segm.py index c75c2589d..293d13ece 100755 --- a/cellacdc/segm.py +++ b/cellacdc/segm.py @@ -518,7 +518,7 @@ def main(self): model_name = win.selectedModel - if model_name == 'thresholding': + if model_name in ('thresholding', 'Automatic thresholding'): win = apps.QDialogAutomaticThresholding( parent=self, isSegm3D=self.isSegm3D ) @@ -528,6 +528,9 @@ def main(self): return self.model_kwargs = win.segment_kwargs + if model_name == 'Automatic thresholding': + model_name = 'thresholding' + self.log(f'Downloading {model_name} (if needed)...') self.downloadWin = apps.downloadModel(model_name, parent=self) self.downloadWin.download() diff --git a/cellacdc/segm_utils.py b/cellacdc/segm_utils.py index 790e59719..880defd9e 100644 --- a/cellacdc/segm_utils.py +++ b/cellacdc/segm_utils.py @@ -3,11 +3,10 @@ import time from .core import segm_model_segment, post_process_segm from .features import custom_post_process_segm -from . import io, plot +from . import io, plot, regionprops import inspect - import os # for dbug import json # for dbug @@ -38,6 +37,22 @@ def find_overlap(lab_1, lab_2): return ID_overlap +def get_best_overlapping_label(label_img, obj, allowed_labels): + allowed_labels = set(allowed_labels) + if len(allowed_labels) == 0: + return None + + overlapping_labels = label_img[obj.slice][obj.image] + if overlapping_labels.size == 0: + return None + + overlapping_labels = overlapping_labels[np.isin(overlapping_labels, tuple(allowed_labels))] + if overlapping_labels.size == 0: + return None + + labels, counts = np.unique(overlapping_labels, return_counts=True) + return labels[np.argmax(counts)] + def get_obj_from_rps(rps, ID): for obj in rps: if obj.label == ID: @@ -163,10 +178,18 @@ def boxes_overlap(bbox1, bbox2): # # Use np.unique once on the combined array # return np.unique(border_labels[border_labels != 0]) -def single_cell_seg(model, prev_lab, curr_lab, curr_img, IDs, new_unique_ID, - win, posData, distance_filler_growth=1, +def single_cell_seg(model, prev_lab, curr_lab, curr_img, + IDs, new_unique_ID, + posData, distance_filler_growth=1, overlap_threshold=0.5, padding=0.4, export_bbox_for_training=False, + model_kwargs=None, + preproc_recipe=None, + applyPostProcessing=False, + standardPostProcessKwargs=None, + customPostProcessFeatures=None, + customPostProcessGroupedFeatures=None, + debug=False, ): """ Function to segment single cells in the current frame using the previous frame segmentation as a reference. @@ -178,12 +201,17 @@ def single_cell_seg(model, prev_lab, curr_lab, curr_img, IDs, new_unique_ID, curr_img: current frame image IDs: list of IDs of the cells to segment new_unique_ID: ID to start labeling new cells - win: from the gui window which sets model params posData: position data (see rest of acdc) distance_filler_growth: distance to grow the other IDs to fill the background overlap_threshold: minimum overlap percentage to consider a cell already segmented padding: padding around the cell to segment export_bbox_for_training: if True, export bounding boxes for training model + model_kwargs: keyword arguments to pass to the segmentation model + preproc_recipe: preprocessing recipe to apply before segmentation + applyPostProcessing: if True, apply post-processing to the segmentation + standardPostProcessKwargs: keyword arguments for standard post-processing + customPostProcessFeatures: custom features for post-processing segmentation + customPostProcessGroupedFeatures: custom grouped features for post-processing Returns: curr_lab: current frame segmentation with the segmented cells @@ -194,12 +222,10 @@ def single_cell_seg(model, prev_lab, curr_lab, curr_img, IDs, new_unique_ID, if export_bbox_for_training: bboxs_for_debug = [] - model_kwargs = win.model_kwargs - preproc_recipe = win.preproc_recipe - applyPostProcessing = win.applyPostProcessing - standardPostProcessKwargs = win.standardPostProcessKwargs - customPostProcessFeatures = win.customPostProcessFeatures - customPostProcessGroupedFeatures = win.customPostProcessGroupedFeatures + if model_kwargs is None: + model_kwargs = {} + if standardPostProcessKwargs is None: + standardPostProcessKwargs = {} prev_rp = skimage.measure.regionprops(prev_lab) prev_lab_shape = prev_lab.shape @@ -210,23 +236,37 @@ def single_cell_seg(model, prev_lab, curr_lab, curr_img, IDs, new_unique_ID, assigned_IDs = [] uses_diameter = inspect.signature(model.segment).parameters.get('diameter', None) is not None - for IDs, bbox in zip(IDs_bboxs, bboxs): + + if debug: + imgs_to_show = { + i: [] for i in range(len(bboxs)) + } + for i, (IDs, bbox) in enumerate(zip(IDs_bboxs, bboxs)): box_x_min, box_x_max, box_y_min, box_y_max = bbox box_curr_img = curr_img[box_x_min:box_x_max, box_y_min:box_y_max].copy() box_curr_lab = curr_lab[box_x_min:box_x_max, box_y_min:box_y_max] + + if debug: + imgs_to_show[i].append(box_curr_img.copy()) + imgs_to_show[i].append(box_curr_lab.copy()) box_curr_lab_other_IDs = box_curr_lab.copy() IDs = np.array(IDs) box_curr_lab_other_IDs[np.isin(box_curr_lab_other_IDs, IDs)] = 0 box_curr_lab_other_IDs_grown = skimage.segmentation.expand_labels(box_curr_lab_other_IDs, distance=distance_filler_growth) + if debug: + imgs_to_show[i].append(box_curr_lab_other_IDs_grown.copy()) # Fill other IDs with random samples from the background indices_to_fill = np.where(box_curr_lab_other_IDs_grown != 0) box_background = box_curr_img[box_curr_lab_other_IDs_grown==0] random_samples = np.random.choice(box_background, size=indices_to_fill[0].shape, replace=True) box_curr_img[indices_to_fill] = random_samples + + if debug: + imgs_to_show[i].append(box_curr_img.copy()) # Run model, give it the diameter of cell if possible if uses_diameter: @@ -248,6 +288,9 @@ def single_cell_seg(model, prev_lab, curr_lab, curr_img, IDs, new_unique_ID, posData=posData, ) + if debug: + imgs_to_show[i].append(box_model_lab.copy()) + if export_bbox_for_training: bboxs_for_debug.append([IDs, bbox, box_model_lab.copy(), box_curr_lab.copy()]) @@ -269,14 +312,14 @@ def single_cell_seg(model, prev_lab, curr_lab, curr_img, IDs, new_unique_ID, ### maybe add roi extension if cells are deleted... # Find the overlap between the model segmentation and the other IDs - overlap = find_overlap(box_model_lab, box_curr_lab_other_IDs) + overlaps = find_overlap(box_model_lab, box_curr_lab_other_IDs) # Set overlapping regions to 0, so already segmented cells are not overwritten - for ID, overlap_perc in overlap: - if overlap_perc > overlap_threshold: - box_model_lab[box_model_lab == ID] = 0 + IDs_to_filter = [ID for ID, overlap_perc in overlaps if overlap_perc > overlap_threshold] + if IDs_to_filter: + box_model_lab[np.isin(box_model_lab, IDs_to_filter)] = 0 - rp_model_lab = skimage.measure.regionprops(box_model_lab) + rp_model_lab = regionprops.acdcRegionprops(box_model_lab,precache_centroids=False) for obj in rp_model_lab: box_curr_lab_other_IDs[box_model_lab == obj.label] = new_unique_ID assigned_IDs.append(new_unique_ID) @@ -321,4 +364,6 @@ def single_cell_seg(model, prev_lab, curr_lab, curr_img, IDs, new_unique_ID, with open(json_filepath, 'w') as f: json.dump(loaded_dict, f, indent=4) + if debug: + return curr_lab, assigned_IDs, IDs_bboxs, bboxs, imgs_to_show return curr_lab, assigned_IDs, IDs_bboxs, bboxs \ No newline at end of file diff --git a/cellacdc/segmentation.py b/cellacdc/segmentation.py deleted file mode 100644 index 9bfba56ad..000000000 --- a/cellacdc/segmentation.py +++ /dev/null @@ -1,139 +0,0 @@ -import numpy as np - -import skimage.segmentation -import skimage.measure - -import cv2 - -def _find_contours_2D( - image, bbox_lower_coords=(0, 0), all=False, closed=True - ): - mode = cv2.RETR_CCOMP if all else cv2.RETR_EXTERNAL - contours, _ = cv2.findContours(image, mode, cv2.CHAIN_APPROX_NONE) - - if all: - all_contours = [ - np.squeeze(contour, axis=1)+bbox_lower_coords - for contour in contours - ] - if closed: - all_contours = [ - np.vstack((contour, contour[0])) for contour in contours - ] - return all_contours - else: - contour = np.squeeze(contours[0], axis=1) - if closed: - contour = np.vstack((contour, contour[0])) - contour = contour + bbox_lower_coords - return contour - -def find_obj_contour( - obj: skimage.measure._regionprops.RegionProperties, all=False, - local=False, do_z_max_proj=False, closed=True - ): - is3D = obj.image.ndim == 3 - bbox_y_idx = 1 if is3D else 0 - - if local: - bbox_lower_coords=(0, 0) - else: - min_y, min_x = obj.bbox[bbox_y_idx:bbox_y_idx+2] - bbox_lower_coords = (min_x, min_y) - - if is3D and do_z_max_proj: - is3D = False - obj_image = obj.max(axis=0).astype(np.uint8) - else: - obj_image = obj.image.astype(np.uint8) - - kwargs = { - 'bbox_lower_coords': bbox_lower_coords, - 'all':all, 'closed': closed - } - if is3D: - contours = [ - _find_contours_2D(image_z, **kwargs) for image_z in obj_image - ] - else: - contours = _find_contours_2D(obj_image, **kwargs) - return contours - -def find_contours( - label_img, connectivity=1, mode='thick', background=0, - return_coords=False, **kwargs - ): - """Return bool array where boundaries between labeled regions are True. - If `return_coords` is True then return also a list of objects' contours - coordinates. - - Parameters - ---------- - label_img : (M, N[, P]) ndarray - An array in which different regions are labeled with either different - integers or boolean values. - connectivity : int, optional - int in {1, ..., `label_img.ndim`}, optional - A pixel is considered a boundary pixel if any of its neighbors - has a different label. `connectivity` controls which pixels are - considered neighbors. A connectivity of 1 (default) means - pixels sharing an edge (in 2D) or a face (in 3D) will be - considered neighbors. A connectivity of `label_img.ndim` means - pixels sharing a corner will be considered neighbors. Default is 1. - mode : str, optional - How to mark the boundaries: - - thick: any pixel not completely surrounded by pixels of the - same label (defined by `connectivity`) is marked as a boundary. - This results in boundaries that are 2 pixels thick. - - inner: outline the pixels *just inside* of objects, leaving - background pixels untouched. - - outer: outline pixels in the background around object - boundaries. When two objects touch, their boundary is also - marked. - - subpixel: return a doubled image, with pixels *between* the - original pixels marked as boundary where appropriate., - - By default 'thick' - background : int, optional - For modes 'inner' and 'outer', a definition of a background - label is required. See `mode` for descriptions of these two, - by default 0 - return_coords : bool, optional - If ``True``, also return a list of objects' contours coordinates, - by default False - kwargs : dict, optional - Additional arguments passed `acdctools.segmentation.find_obj_contour` - function. This function uses the opencv find contours function - `cv2.findContours`. Used only if `mode='inner'`. - - Returns - ------- - boundaries : ndarray of bool, same shape as `label_img` - A bool image where ``True`` represents a boundary pixel. For - `mode` equal to 'subpixel', ``boundaries.shape[i]`` is equal - to ``2 * label_img.shape[i] - 1`` for all ``i`` (a pixel is - inserted in between all other pairs of pixels). - contours_coords: list of ndarray - A list of ndarrays with shape (N, n) where `n` is the number of - dimensions of `label_img` and `N` is the number of points in each - object's contour. The list contains one ndarray per object in - `label_img`. - The ordering of columns follows the numpy's order of dimensions - convention, e.g., for 2-D, the first and second column are the - y and x coordinates, respectively. - Only provided if `return_coords` is True. - """ - boundaries = skimage.segmentation.find_boundaries( - label_img, connectivity=connectivity, mode=mode, background=background - ) - if not return_coords: - return boundaries - - is2D = label_img.ndim == 2 - rp = skimage.measure.regionprops(label_img) - contours_coords = [] - for obj in rp: - if mode == 'inner' and is2D: - pass - else: - pass diff --git a/cellacdc/trackers/CellACDC/CellACDC_tracker.py b/cellacdc/trackers/CellACDC/CellACDC_tracker.py index 07cfe841b..c02bddb99 100755 --- a/cellacdc/trackers/CellACDC/CellACDC_tracker.py +++ b/cellacdc/trackers/CellACDC/CellACDC_tracker.py @@ -5,82 +5,137 @@ import numpy as np from skimage.measure import regionprops +from cellacdc.regionprops import acdcRegionprops from skimage.segmentation import relabel_sequential -from cellacdc import core, printl +from cellacdc import core, printl, debugutils + +try: + from cellacdc.precompiled.precompiled_functions import ( + calc_IoA_matrix_2D as _calc_IoA_matrix_2D_cython, + calc_IoA_matrix_3D as _calc_IoA_matrix_3D_cython, + ) + _HAS_CYTHON_IOA = True + print('tracking: imported precompiled IoA helpers.') +except ImportError: + _HAS_CYTHON_IOA = False + print('[WARNING]: tracking could not import precompiled IoA helpers, falling back to NumPy implementation.') DEBUG = False +def _normalize_specific_IDs(specific_IDs): + if specific_IDs is None: + return None + if isinstance(specific_IDs, (list, tuple, set, np.ndarray)): + return set(specific_IDs) + return {specific_IDs} + +def _filter_subset_assignments(old_IDs, tracked_IDs, all_curr_IDs, specific_IDs): + if specific_IDs is None: + return old_IDs, tracked_IDs + + selected_curr_IDs = set(specific_IDs) + other_curr_IDs = set(all_curr_IDs).difference(selected_curr_IDs) + filtered_old_IDs = [] + filtered_tracked_IDs = [] + for old_ID, tracked_ID in zip(old_IDs, tracked_IDs): + if tracked_ID in other_curr_IDs: + continue + filtered_old_IDs.append(old_ID) + filtered_tracked_IDs.append(tracked_ID) + + return filtered_old_IDs, filtered_tracked_IDs + def calc_Io_matrix(lab, prev_lab, rp, prev_rp, IDs_curr_untracked=None, - denom:str='area_prev', IDs=None): - # maybe its faster to calculate IoU not via mask but via area1 / (area1 + area2 - intersection) - IDs_prev = [] - if IDs_curr_untracked is None: + specific_IDs=None, + denom:str='area_prev'): + specific_IDs = _normalize_specific_IDs(specific_IDs) + if IDs_curr_untracked is None and isinstance(rp, acdcRegionprops): + IDs_curr_untracked = rp.IDs + elif IDs_curr_untracked is None: IDs_curr_untracked = [obj.label for obj in rp] + elif not isinstance(IDs_curr_untracked, list): + IDs_curr_untracked = list(IDs_curr_untracked) - IoA_matrix = np.zeros((len(rp), len(prev_rp))) - rp_mapper = {obj.label: obj for obj in rp} - idx_mapper = {obj.label: i for i, obj in enumerate(rp)} - # For each ID in previous frame get IoA with all current IDs - # Rows: IDs in current frame, columns: IDs in previous frame - - ### just an idea for having area_curr as a denom possibility: just switch all around... - # if denom == 'area_curr': - # # switch prev with curr - # prev_lab_temp, prev_rp_temp = prev_lab.copy(), prev_rp.copy() - # prev_lab, prev_rp = lab.copy, rp.copy() - # lab, rp = prev_lab_temp, prev_rp_temp - - if not denom in ['area_prev', 'union']: + if specific_IDs is not None: + IDs_curr_untracked = [ + ID for ID in IDs_curr_untracked if ID in specific_IDs + ] + + if isinstance(prev_rp, acdcRegionprops): + IDs_prev = prev_rp.IDs + + else: + IDs_prev = [obj.label for obj in prev_rp] + + if not IDs_curr_untracked: + return np.zeros((0, len(prev_rp))), IDs_curr_untracked, IDs_prev + + if denom not in ('area_prev', 'union'): raise ValueError( "Invalid denom value. Use 'area_prev' or 'union'." ) - # prev_label_positions = {ID_prev: np.where(prev_lab == ID_prev)[0] for ID_prev in set(prev_lab) if ID_prev != 0} - # if denom == 'union': - # temp_lab = np.zeros(lab.shape, dtype=bool) - for j, obj_prev in enumerate(prev_rp): - ID_prev = obj_prev.label - IDs_prev.append(ID_prev) - # if IDs is not None and ID_prev not in IDs: - # continue + if _HAS_CYTHON_IOA: + use_union = denom == 'union' + curr_IDs_arr = np.array(IDs_curr_untracked, dtype=np.uint32) + prev_IDs_arr = np.array(IDs_prev, dtype=np.uint32) + prev_areas_arr = np.array([obj.area for obj in prev_rp], dtype=np.uint32) + if use_union: + rp_mapper = {obj.label: obj for obj in rp} + curr_areas_arr = np.array( + [rp_mapper[ID].area for ID in IDs_curr_untracked], dtype=np.uint32 + ) + else: + curr_areas_arr = np.empty(0, dtype=np.uint32) + lab_u32 = np.asarray(lab, dtype=np.uint32) + prev_lab_u32 = np.asarray(prev_lab, dtype=np.uint32) + if lab.ndim == 2: + IoA_matrix = _calc_IoA_matrix_2D_cython( + lab_u32, prev_lab_u32, curr_IDs_arr, prev_IDs_arr, + prev_areas_arr, curr_areas_arr, use_union, + ) + else: + IoA_matrix = _calc_IoA_matrix_3D_cython( + lab_u32, prev_lab_u32, curr_IDs_arr, prev_IDs_arr, + prev_areas_arr, curr_areas_arr, use_union, + ) + return IoA_matrix, IDs_curr_untracked, IDs_prev - if denom == 'area_prev': # or denom == 'area_curr': + # --- pure-Python fallback (used when Cython extension is not compiled) --- + IoA_matrix = np.zeros((len(IDs_curr_untracked), len(prev_rp))) + rp_mapper = {obj.label: obj for obj in rp} + idx_mapper = {ID: i for i, ID in enumerate(IDs_curr_untracked)} + for j, obj_prev in enumerate(prev_rp): + if denom == 'area_prev': denom_val = obj_prev.area - - # Get intersecting IDs between current and object in previous frame intersect_IDs, intersects = np.unique( lab[obj_prev.slice][obj_prev.image], return_counts=True ) for intersect_ID, I in zip(intersect_IDs, intersects): - if intersect_ID == 0: - continue - - if I == 0: + if intersect_ID == 0 or I == 0: continue - if denom == 'union': + if intersect_ID not in rp_mapper: + continue obj_curr = rp_mapper[intersect_ID] - # temp_lab[obj_prev.slice][obj_prev.image] = True - # temp_lab[obj_curr.slice][obj_curr.image] = True - # denom_val = np.count_nonzero(temp_lab) - # temp_lab[:] = False denom_val = obj_prev.area + obj_curr.area - I if denom_val == 0: continue - - idx = idx_mapper[intersect_ID] - IoA = I/denom_val - IoA_matrix[idx, j] = IoA + idx = idx_mapper.get(intersect_ID) + if idx is None: + continue + IoA_matrix[idx, j] = I / denom_val return IoA_matrix, IDs_curr_untracked, IDs_prev def assign( IoA_matrix, IDs_curr_untracked, IDs_prev, IoA_thresh=0.4, aggr_track=None, IoA_thresh_aggr=0.4, daughters_list=None, - IDs=None): + specific_IDs=None): # Determine max IoA between IDs and assign tracked ID if IoA >= IoA_thresh if IoA_matrix.size == 0: return [], [] + max_IoA_col_idx = IoA_matrix.argmax(axis=1) unique_col_idx, counts = np.unique(max_IoA_col_idx, return_counts=True) counts_dict = dict(zip(unique_col_idx, counts)) @@ -187,7 +242,10 @@ def indexAssignment( remove_untracked=False, assign_unique_new_IDs=True, return_assignments=False, - IDs=None + dont_return_tracked_lab=False, + specific_IDs=None, + all_curr_IDs=None, + IDs=None, ): """Replace `old_IDs` in `lab` with `tracked_IDs` while making sure to avoid merging IDs. @@ -229,15 +287,25 @@ def indexAssignment( assignments: dict Returned only if `return_assignments` is True. """ + specific_IDs = _normalize_specific_IDs(specific_IDs) log_debugging( 'start', IDs_curr_untracked=IDs_curr_untracked, old_IDs=old_IDs ) - # Replace untracked IDs with tracked IDs and new IDs with increasing num + if all_curr_IDs is None: + all_curr_IDs = list(IDs_curr_untracked) + old_IDs, tracked_IDs = _filter_subset_assignments( + old_IDs, tracked_IDs, all_curr_IDs, specific_IDs + ) + + # Replace untracked IDs with tracked IDs and new IDs with increasing num. + # When tracking only a subset of current IDs, leave unrelated labels untouched. new_untracked_IDs = [ID for ID in IDs_curr_untracked if ID not in old_IDs] - tracked_lab = lab + + if not dont_return_tracked_lab: + tracked_lab = lab assignments = {} log_debugging( 'assign_unique', @@ -251,9 +319,10 @@ def indexAssignment( new_tracked_IDs = [ uniqueID+i for i in range(len(new_untracked_IDs)) ] - core.lab_replace_values( - tracked_lab, rp, new_untracked_IDs, new_tracked_IDs - ) + if not dont_return_tracked_lab: + core.lab_replace_values( + tracked_lab, rp, new_untracked_IDs, new_tracked_IDs + ) assignments.update(dict(zip(new_untracked_IDs, new_tracked_IDs))) log_debugging( 'new_untracked_and_assign_unique', @@ -271,9 +340,10 @@ def indexAssignment( new_tracked_IDs = [ uniqueID+i for i in range(len(new_IDs_in_trackedIDs)) ] - core.lab_replace_values( - tracked_lab, rp, new_IDs_in_trackedIDs, new_tracked_IDs - ) + if not dont_return_tracked_lab: + core.lab_replace_values( + tracked_lab, rp, new_IDs_in_trackedIDs, new_tracked_IDs + ) assignments.update(dict(zip(new_IDs_in_trackedIDs, new_tracked_IDs))) log_debugging( 'new_untracked_and_tracked', @@ -283,10 +353,15 @@ def indexAssignment( new_tracked_IDs=new_tracked_IDs ) if tracked_IDs: - core.lab_replace_values( - tracked_lab, rp, old_IDs, tracked_IDs, in_place=True - ) - assignments.update(dict(zip(old_IDs, tracked_IDs))) + if not dont_return_tracked_lab: + core.lab_replace_values( + tracked_lab, rp, old_IDs, tracked_IDs, in_place=True + ) + assignments.update({ + old_ID: tracked_ID + for old_ID, tracked_ID in zip(old_IDs, tracked_IDs) + if old_ID != tracked_ID + }) log_debugging( 'tracked', tracked_IDs=tracked_IDs, @@ -295,6 +370,8 @@ def indexAssignment( if not return_assignments: return tracked_lab + elif dont_return_tracked_lab: + return assignments else: return tracked_lab, assignments @@ -305,17 +382,30 @@ def track_frame( return_all=False, aggr_track=None, IoA_matrix=None, IoA_thresh_aggr=None, IDs_prev=None, return_prev_IDs=False, mother_daughters=None, denom_overlap_matrix = 'area_prev', - IDs=None + return_assignments=False, specific_IDs=None, dont_return_tracked_lab=False ): if not np.any(lab): # Skip empty frames return lab + all_curr_IDs = ( + list(IDs_curr_untracked) + if IDs_curr_untracked is not None else None + ) + if isinstance(rp, acdcRegionprops) and all_curr_IDs is None: + all_curr_IDs = rp.IDs + elif all_curr_IDs is None: + all_curr_IDs = [obj.label for obj in rp] + elif not isinstance(all_curr_IDs, list): + all_curr_IDs = list(all_curr_IDs) + if IoA_matrix is None: - IoA_matrix, IDs_curr_untracked, IDs_prev = calc_Io_matrix( + IoA_matrix, tracked_curr_IDs, IDs_prev = calc_Io_matrix( lab, prev_lab, rp, prev_rp, IDs_curr_untracked=IDs_curr_untracked, - denom=denom_overlap_matrix, IDs=IDs + denom=denom_overlap_matrix,specific_IDs=specific_IDs, ) + else: + tracked_curr_IDs = IDs_curr_untracked daughters_list = [] if mother_daughters: @@ -323,38 +413,61 @@ def track_frame( daughters_list.extend(daughters) old_IDs, tracked_IDs = assign( - IoA_matrix, IDs_curr_untracked, IDs_prev, + IoA_matrix, tracked_curr_IDs, IDs_prev, IoA_thresh=IoA_thresh, aggr_track=aggr_track, IoA_thresh_aggr=IoA_thresh_aggr, daughters_list=daughters_list, + specific_IDs=specific_IDs, ) if posData is None and unique_ID is None: unique_ID = max( - (max(IDs_prev, default=0), max(IDs_curr_untracked, default=0)) + (max(IDs_prev, default=0), max(all_curr_IDs, default=0)) ) + 1 elif unique_ID is None: # Compute starting unique ID setBrushID_func(useCurrentLab=True) unique_ID = posData.brushID+1 - if not return_all: + if not return_all and not return_assignments: tracked_lab = indexAssignment( - old_IDs, tracked_IDs, IDs_curr_untracked, + old_IDs, tracked_IDs, tracked_curr_IDs, lab.copy(), rp, unique_ID, assign_unique_new_IDs=assign_unique_new_IDs, + specific_IDs=specific_IDs, + all_curr_IDs=all_curr_IDs, + ) + elif dont_return_tracked_lab: + assignments = indexAssignment( + old_IDs, tracked_IDs, tracked_curr_IDs, + lab.copy(), rp, unique_ID, + assign_unique_new_IDs=assign_unique_new_IDs, + return_assignments=True, specific_IDs=specific_IDs, + dont_return_tracked_lab=True, + all_curr_IDs=all_curr_IDs, ) else: tracked_lab, assignments = indexAssignment( - old_IDs, tracked_IDs, IDs_curr_untracked, + old_IDs, tracked_IDs, tracked_curr_IDs, lab.copy(), rp, unique_ID, assign_unique_new_IDs=assign_unique_new_IDs, - return_assignments=return_all, + return_assignments=True, specific_IDs=specific_IDs, + all_curr_IDs=all_curr_IDs, ) # old_new_ids = dict(zip(old_IDs, tracked_IDs)) # for now not used, but could be useful in the future - if return_all: + if return_all and dont_return_tracked_lab: + # special case where we want to only get the assignments but need the rest too! + return IoA_matrix, assignments, tracked_IDs + elif return_all: return tracked_lab, IoA_matrix, assignments, tracked_IDs # remove tracked_IDs and change code in CellACDC_tracker.py if causing problems + elif dont_return_tracked_lab: + return assignments + elif return_assignments: + add_info = { + 'assignments': assignments, + } + return tracked_lab, add_info else: return tracked_lab diff --git a/cellacdc/trackers/CellACDC_2steps/CellACDC_2steps_tracker.py b/cellacdc/trackers/CellACDC_2steps/CellACDC_2steps_tracker.py index 4bef3ec46..198e2d6ac 100644 --- a/cellacdc/trackers/CellACDC_2steps/CellACDC_2steps_tracker.py +++ b/cellacdc/trackers/CellACDC_2steps/CellACDC_2steps_tracker.py @@ -13,6 +13,26 @@ import cellacdc.core from ..CellACDC import CellACDC_tracker +from ..CellACDC.CellACDC_tracker import _normalize_specific_IDs + +from cellacdc._types import NotGUIParam + +def _format_tracking_result( + tracked_lab, + assignments, + to_track_tracked_objs_2nd_step, + return_assignments, + dont_return_tracked_lab, + ): + add_info = { + 'assignments': assignments, + 'to_track_tracked_objs_2nd_step': to_track_tracked_objs_2nd_step, + } + + if dont_return_tracked_lab: + return add_info + + return tracked_lab, add_info # always return extra info class SearchRangeUnits: values = ['pixels', 'micrometre'] @@ -114,11 +134,14 @@ def track_frame( overlap_threshold=0.4, search_range_unit: SearchRangeUnits='pixels', lost_IDs_search_range=10, - unique_ID: Integer=None + unique_ID: Integer=None, + specific_IDs: NotGUIParam=None, + dont_return_tracked_lab=False, + return_assignments=False, ): """Track two consecutive frames in two steps. First step based on `overlap_threshold` and second step tracks only lost objects to new - objects detemined at first step. + objects determined at first step. Parameters ---------- @@ -148,20 +171,30 @@ def track_frame( If not None, uses this as starting ID for all the untracked objects. If None, this will be calculated based on the two input frames. """ + specific_IDs = _normalize_specific_IDs(specific_IDs) to_track_tracked_objs_2nd_step = None prev_rp = skimage.measure.regionprops(prev_frame_lab) curr_rp = skimage.measure.regionprops(current_frame_lab) - tracked_lab_1st_step = CellACDC_tracker.track_frame( + tracked_lab_1st_step, add_info = CellACDC_tracker.track_frame( prev_frame_lab, prev_rp, current_frame_lab, curr_rp, IoA_thresh=overlap_threshold, - return_prev_IDs=False, - unique_ID=unique_ID + return_prev_IDs=False, + unique_ID=unique_ID, + specific_IDs=specific_IDs, + return_assignments=True, ) + assignments_step_1 = add_info['assignments'] + selected_tracked_IDs = None + if specific_IDs is not None: + selected_tracked_IDs = { + assignments_step_1.get(curr_ID, curr_ID) + for curr_ID in specific_IDs + } prev_rp_mapper = {obj.label: obj for obj in prev_rp} @@ -176,27 +209,43 @@ def track_frame( } if not lost_rp_mapper: - return tracked_lab_1st_step, to_track_tracked_objs_2nd_step + return _format_tracking_result( + tracked_lab_1st_step, + assignments_step_1, + to_track_tracked_objs_2nd_step, + return_assignments, + dont_return_tracked_lab, + ) new_rp_mapper = { obj.label: obj for obj in tracked_rp_1st_step + if ( + selected_tracked_IDs is None + or obj.label in selected_tracked_IDs + ) if prev_rp_mapper.get(obj.label) is None } - + if not new_rp_mapper: - return tracked_lab_1st_step, to_track_tracked_objs_2nd_step + return _format_tracking_result( + tracked_lab_1st_step, + assignments_step_1, + to_track_tracked_objs_2nd_step, + return_assignments, + dont_return_tracked_lab, + ) ndim = current_frame_lab.ndim lost_IDs_coords = np.zeros((len(lost_rp_mapper), ndim)) lost_IDs_idx_to_obj_mapper = {} for lost_idx, lost_obj in enumerate(lost_rp_mapper.values()): - lost_IDs_coords[lost_idx] = lost_obj.centroid + lost_IDs_coords[lost_idx] = lost_obj.centroid # we have overwritten RP so its always cached lost_IDs_idx_to_obj_mapper[lost_idx] = lost_obj new_IDs_coords = np.zeros((len(new_rp_mapper), ndim)) new_IDs_idx_to_obj_mapper = {} for new_idx, new_obj in enumerate(new_rp_mapper.values()): - new_IDs_coords[new_idx] = new_obj.centroid + new_IDs_coords[new_idx] = new_obj.centroid # we have overwritten RP so its always cached new_IDs_idx_to_obj_mapper[new_idx] = new_obj if search_range_unit == 'micrometre': @@ -229,21 +278,45 @@ def track_frame( tracked_objs_2nd_step.append(lost_IDs_idx_to_obj_mapper[i]) if not IDs_to_track: - return tracked_lab_1st_step, to_track_tracked_objs_2nd_step + return _format_tracking_result( + tracked_lab_1st_step, + assignments_step_1, + to_track_tracked_objs_2nd_step, + return_assignments, + dont_return_tracked_lab, + ) - tracked_lab_2nd_step = cellacdc.core.lab_replace_values( - tracked_lab_1st_step, - tracked_rp_1st_step, - IDs_to_track, - tracked_IDs_2nd_step - ) + if not dont_return_tracked_lab: + tracked_lab_2nd_step = cellacdc.core.lab_replace_values( + tracked_lab_1st_step, + tracked_rp_1st_step, + IDs_to_track, + tracked_IDs_2nd_step + ) + else: + tracked_lab_2nd_step = None if self._annot_obj_2nd_step: to_track_tracked_objs_2nd_step = ( objs_to_track, tracked_objs_2nd_step ) - return tracked_lab_2nd_step, to_track_tracked_objs_2nd_step + assignments_step_2 = dict(zip(IDs_to_track, tracked_IDs_2nd_step)) + for current_ID, tracked_ID in list(assignments_step_1.items()): + final_tracked_ID = assignments_step_2.get(tracked_ID) + if final_tracked_ID is not None: + assignments_step_1[current_ID] = final_tracked_ID + + for current_ID, tracked_ID in assignments_step_2.items(): + assignments_step_1.setdefault(current_ID, tracked_ID) + + return _format_tracking_result( + tracked_lab_2nd_step, + assignments_step_1, + to_track_tracked_objs_2nd_step, + return_assignments, + dont_return_tracked_lab, + ) def updateGuiProgressBar(self, signals): if signals is None: diff --git a/cellacdc/trackers/CellACDC_normal_division/CellACDC_normal_division_tracker.py b/cellacdc/trackers/CellACDC_normal_division/CellACDC_normal_division_tracker.py index 8e47215a9..560376dbd 100644 --- a/cellacdc/trackers/CellACDC_normal_division/CellACDC_normal_division_tracker.py +++ b/cellacdc/trackers/CellACDC_normal_division/CellACDC_normal_division_tracker.py @@ -4,33 +4,12 @@ from cellacdc.core import getBaseCca_df, printl from cellacdc.myutils import checked_reset_index, checked_reset_index_Cell_ID import numpy as np -from skimage.measure import regionprops from tqdm import tqdm import pandas as pd -from cellacdc.myutils import exec_time from cellacdc._types import NotGUIParam import copy import cellacdc.debugutils as debugutils - -# def filter_cols(df): -# """ -# Filters the columns of a DataFrame based on a predefined set of column names. -# 'generation_num_tree', 'root_ID_tree', 'sister_ID_tree', 'parent_ID_tree', 'parent_ID_tree', 'emerg_frame_i', 'division_frame_i' -# plus any column that starts with 'sister_ID_tree' - -# Parameters: -# - df (pandas.DataFrame): The input DataFrame. - -# Returns: -# - pandas.DataFrame: The filtered DataFrame containing only the specified columns. -# """ -# lin_tree_cols = {'generation_num_tree', 'root_ID_tree', -# 'sister_ID_tree', 'parent_ID_tree', -# 'parent_ID_tree', 'emerg_frame_i', -# 'division_frame_i', 'is_history_known'} -# sis_cols = {col for col in df.columns if col.startswith('sister_ID_tree')} -# lin_tree_cols = lin_tree_cols | sis_cols -# return df[list(lin_tree_cols)] +from cellacdc.regionprops import acdcRegionprops as acdcRegionprops def reorg_sister_cells_for_export(lineage_tree_frame): """ @@ -60,45 +39,6 @@ def reorg_sister_cells_for_export(lineage_tree_frame): return lineage_tree_frame -# def reorg_sister_cells_inner_func(row): -# """ -# Reorganizes the sister cells in a row of a DataFrame. Used as an inner function for apply. - -# Parameters: -# - row (pandas.Series): The input row of the DataFrame (alredy filtered for the sister columns). -# Returns: -# - pandas.Series: The reorganized row with the sister cells. -# """ - -# values = [int(i) for i in row if i not in {0, -1} and not np.isnan(i)] or [-1] -# values = list(set(values)) -# return values - - -# def reorg_sister_cells_for_import(df): -# """ -# Reorganizes the sister cells for import. - -# This function takes a DataFrame `df` as input and performs the following steps: -# 1. Identifies the sister columns in the DataFrame. -# 2. Removes any values that are equal to 0 or -1 from the sister columns. (Which both represent no sister cell) -# 3. Converts the remaining values in the sister columns to a set. -# 4. Converts the set of values to a list if it is not empty, otherwise assigns [-1] to the sister column. (It actually shouldn't be empty, but just in case...) -# 5. Removes the sister columns from the DataFrame. And adds the list as the new 'sister_ID_tree' column. - -# Parameters: -# - df (pandas.DataFrame): The input DataFrame. - -# Returns: -# - df (pandas.DataFrame): The modified DataFrame with reorganized sister cells. -# """ -# sister_cols = [col for col in df.columns if col.startswith('sister_ID_tree')] # handling sister columns -# df.loc[:, 'sister_ID_tree'] = df[sister_cols].apply(reorg_sister_cells_inner_func, axis=1) -# sister_cols.remove('sister_ID_tree') -# df = df.drop(columns=sister_cols) -# df = checked_reset_index_Cell_ID(df) -# return df - def mother_daughter_assign(IoA_matrix, IoA_thresh_daughter, min_daughter, max_daughter, IoA_thresh_instant=None): """ Identifies cells that have not undergone division based on the input IoA matrix. @@ -151,13 +91,8 @@ def mother_daughter_assign(IoA_matrix, IoA_thresh_daughter, min_daughter, max_da else: should_remove_idx.append(False) - # printl(f'length of mother_daughters: {len(mother_daughters), len(should_remove_idx)}') mother_daughters = [mother_daughters[i] for i, remove in enumerate(should_remove_idx) if not remove] - # daughters_li = [] - # for _, daughters in mother_daughters: - # daughters_li.extend(daughters) - return aggr_track, mother_daughters def added_lineage_tree_to_cca_df(added_lineage_tree): @@ -241,33 +176,6 @@ def IoA_index_daughter_to_ID(daughters, assignments, IDs_curr_untracked): return daughter_IDs -# def update_fam_dynamically(families, fixed_df, Cell_IDs_fixed=None): -# if Cell_IDs_fixed is None: -# Cell_IDs_fixed = fixed_df.index -# for idx, family in enumerate(families): -# # Keep only cellinfos where cell_id is in Cell_IDs_fixed -# families[idx] = [cellinfo for cellinfo in family if cellinfo[0] not in Cell_IDs_fixed] - -# families = [family for family in families if family] # Remove empty families -# handled_cells = set() -# for family in families: -# root_ID = family[0][0] # The first cell in the family is the root -# try: -# relevant_cells = fixed_df.loc[fixed_df['root_ID_tree'] == root_ID] -# except: -# printl(fixed_df['root_ID_tree']) -# for relevant_cell in relevant_cells.index: -# # Update the family with the generation number and root ID -# family.append((relevant_cell, relevant_cells.loc[relevant_cell, 'generation_num_tree'])) -# handled_cells.update(relevant_cells.index) - -# for cell_id in Cell_IDs_fixed: -# if cell_id not in handled_cells: -# # If the cell is not handled, create a new family for it -# families.append([(cell_id, fixed_df.loc[cell_id, 'generation_num_tree'])]) - -# return families - class normal_division_tracker: """ A class that tracks cell divisions in a video sequence. The tracker uses the Intersection over Area (IoA) metric to track cells and identify daughter cells. @@ -323,7 +231,8 @@ def __init__(self, self.tracked_video[0] = segm_video[0] def track_frame(self, frame_i, lab=None, prev_lab=None, rp=None, prev_rp=None, - IDs=None, unique_ID=None): + IDs=None, unique_ID=None, + return_assignments=False, specific_IDs=None, dont_return_tracked_lab=False): """ Tracks a single frame in the video sequence. @@ -342,42 +251,75 @@ def track_frame(self, frame_i, lab=None, prev_lab=None, rp=None, prev_rp=None, prev_lab = self.tracked_video[frame_i-1] if rp is None: - self.rp = regionprops(lab.copy()) + self.rp = acdcRegionprops(lab.copy(), precache_centroids=False) else: self.rp = rp if prev_rp is None: - prev_rp = regionprops(prev_lab.copy()) - - IoA_matrix, self.IDs_curr_untracked, self.IDs_prev = calc_Io_matrix(lab, - prev_lab, - self.rp, - prev_rp, - IDs=IDs, - ) - self.aggr_track, self.mother_daughters = mother_daughter_assign(IoA_matrix, - IoA_thresh_daughter=self.IoA_thresh_daughter, - min_daughter=self.min_daughter, - max_daughter=self.max_daughter, - IoA_thresh_instant=self.IoA_thresh - ) - self.tracked_lab, IoA_matrix, self.assignments, _ = track_frame_base(prev_lab, - prev_rp, - lab, - self.rp, - IoA_thresh=self.IoA_thresh, - IoA_matrix=IoA_matrix, - aggr_track=self.aggr_track, - IoA_thresh_aggr=self.IoA_thresh_aggressive, - IDs_curr_untracked=self.IDs_curr_untracked, - IDs_prev=self.IDs_prev, - return_all=True, - mother_daughters=self.mother_daughters, - unique_ID=unique_ID - ) + prev_rp = acdcRegionprops(prev_lab.copy(), precache_centroids=False) + + full_IoA_matrix, full_curr_IDs, self.IDs_prev = calc_Io_matrix( + lab, + prev_lab, + self.rp, + prev_rp, + ) + IoA_matrix, self.IDs_curr_untracked, _ = calc_Io_matrix( + lab, + prev_lab, + self.rp, + prev_rp, + specific_IDs=specific_IDs, + ) + full_aggr_track, full_mother_daughters = mother_daughter_assign( + full_IoA_matrix, + IoA_thresh_daughter=self.IoA_thresh_daughter, + min_daughter=self.min_daughter, + max_daughter=self.max_daughter, + IoA_thresh_instant=self.IoA_thresh, + ) + + subset_idx_mapper = { + curr_ID: idx for idx, curr_ID in enumerate(self.IDs_curr_untracked) + } + self.aggr_track = [ + subset_idx_mapper[full_curr_IDs[idx]] + for idx in full_aggr_track + if full_curr_IDs[idx] in subset_idx_mapper + ] + self.mother_daughters = [] + for mother_idx, daughter_idxs in full_mother_daughters: + subset_daughter_idxs = [ + subset_idx_mapper[full_curr_IDs[idx]] + for idx in daughter_idxs + if full_curr_IDs[idx] in subset_idx_mapper + ] + if subset_daughter_idxs: + self.mother_daughters.append((mother_idx, subset_daughter_idxs)) - - self.tracked_video[frame_i] = self.tracked_lab + out = track_frame_base( + prev_lab, + prev_rp, + lab, + self.rp, + IoA_thresh=self.IoA_thresh, + IoA_matrix=IoA_matrix, + aggr_track=self.aggr_track, + IoA_thresh_aggr=self.IoA_thresh_aggressive, + IDs_curr_untracked=self.IDs_curr_untracked, + IDs_prev=self.IDs_prev, + return_all=True, + mother_daughters=self.mother_daughters, + unique_ID=unique_ID, + specific_IDs=specific_IDs, + return_assignments=return_assignments, + dont_return_tracked_lab=dont_return_tracked_lab, + ) + if dont_return_tracked_lab: + IoA_matrix, self.assignments, self.tracked_IDs = out + else: + self.tracked_lab, IoA_matrix, self.assignments, self.tracked_IDs = out + self.tracked_video[frame_i] = self.tracked_lab class normal_division_lineage_tree: """ @@ -594,8 +536,8 @@ def init_lineage_tree(self, lab=None, first_df=None, frame_i=None): if lab is not None: - rp = regionprops(lab) - labels = [obj.label for obj in rp] + rp = acdcRegionprops(lab, precache_centroids=False) + labels = rp.IDs cca_df = pd.DataFrame({ 'Cell_ID': labels, }) @@ -730,10 +672,10 @@ def real_time(self, frame_i, lab, prev_lab, rp=None, prev_rp=None): None """ if rp is None: - rp = regionprops(lab) + rp = acdcRegionprops(lab, precache_centroids=False) if prev_rp is None: - prev_rp = regionprops(prev_lab) + prev_rp = acdcRegionprops(prev_lab, precache_centroids=False) IoA_matrix, self.IDs_curr_untracked, self.IDs_prev = calc_Io_matrix(lab, prev_lab, rp, prev_rp) @@ -751,7 +693,7 @@ def real_time(self, frame_i, lab, prev_lab, rp=None, prev_rp=None): self.mother_daughters = filtered_mother_daughters curr_IDs = set(self.IDs_curr_untracked) - prev_IDs = {obj.label for obj in prev_rp} + prev_IDs = set(prev_rp.IDs) new_IDs = curr_IDs - prev_IDs self.frames_for_dfs.add(frame_i) self.add_new_frame(frame_i, self.mother_daughters, self.IDs_prev, self.IDs_curr_untracked, None, curr_IDs, new_IDs) @@ -842,80 +784,6 @@ def update_df_li_locally(self, df, frame_i): df_data.loc[ID] = cell_row - # This will probably be made obsolete by the gui_mode version - # def insert_lineage_df(self, lineage_df, frame_i, update_fams=True, - # consider_children=True, raw_input=False, propagate=True, - # relevant_cells=None): - # """ - # Insert or replace a lineage DataFrame at a given frame index, optionally updating families and propagating changes. - - # Args: - # lineage_df (pd.DataFrame): The lineage DataFrame to insert. - # frame_i (int): The index of the frame. - # update_fams (bool, optional): If True, update families based on the changes. Defaults to True. - # consider_children (bool, optional): If True, update children of the inserted frame. Defaults to True. - - # Returns: - # None - # """ - # if not self.gui_mode: - # printl("here") - # if not raw_input: - # lineage_df = reorg_sister_cells_for_import(lineage_df) - # lineage_df = filter_cols(lineage_df) - - # lineage_df = checked_reset_index_Cell_ID(lineage_df) - # len_lineage_list = len(self.lineage_list) - # if frame_i == len_lineage_list: - # self.lineage_list.append(lineage_df) - # self.frames_for_dfs.add(frame_i) - - # self.update_df_li_locally(lineage_df, frame_i) - - # if propagate: - # out = update_consistency(df_li=self.lineage_list, fixed_frame_i=frame_i, - # consider_children=consider_children, Cell_IDs_fixed=relevant_cells, - # families=self.families if update_fams else None) - # if update_fams: - # self.lineage_list, self.families = out - # else: - # self.lineage_list = out - - # elif frame_i < len_lineage_list: - # self.lineage_list[frame_i] = lineage_df - # self.update_df_li_locally(lineage_df, frame_i) - # if propagate: - # out = update_consistency(df_li=self.lineage_list, fixed_frame_i=frame_i, - # consider_children=consider_children, Cell_IDs_fixed=relevant_cells, - # families=self.families if update_fams else None) - # if update_fams: - # self.lineage_list, self.families = out - # else: - # self.lineage_list = out - - - # elif frame_i > len_lineage_list: - # printl(f'WARNING: Frame_i {frame_i} was inserted. The lineage list was only {len(self.lineage_list)} frames long, so the last known lineage tree was copy pasted up to frame_i {frame_i}') - - # original_length = len(self.lineage_list) - # self.lineage_list = self.lineage_list + [self.lineage_list[-1]] * (frame_i - len(self.lineage_list)) - - # self.generate_gen_df_from_df_li(self.lineage_list, force=True) - - # self.lineage_list.append(lineage_df) - - # frame_is = set(range(len(self.lineage_list)-original_length)) - # self.frames_for_dfs = self.frames_for_dfs | frame_is - - # self.update_df_li_locally(lineage_df, frame_i) - # if propagate: - # out = update_consistency(df_li=self.lineage_list, fixed_frame_i=frame_i, - # consider_children=consider_children, Cell_IDs_fixed=relevant_cells, - # families=self.families if update_fams else None) - # if update_fams: - # self.lineage_list, self.families = out - # else: - # self.lineage_list = out def _update_consistency(self, fixed_frame_i=None, fixed_df=None, Cell_IDs_fixed=None, consider_children=True): @@ -1043,63 +911,6 @@ def propagate(self, frame_i, relevant_cells=None): self._update_consistency(fixed_frame_i=frame_i, consider_children=True, Cell_IDs_fixed=relevant_cells) - # This will probably be made obsolete by the gui_mode version - # def load_lineage_df_list(self, df_li): - # """ - # Load a list of lineage DataFrames, reconstructing the lineage tree and families. - - # Args: - # df_li (list): List of acdc_df DataFrames. - - # Returns: - # None - # """ - # df_li = copy.deepcopy(df_li) # Ensure we don't modify the original list - # # Support for first_frame was removed since it is not necessary, just make the df_li correct... - # # Also the tree needs to be init before. Also if df_li does not contain any relevant dfs, nothing happens - # print('Loading lineage data...') - # df_li_new = [] - # families = [] - # families_root_IDs = [] - # added_IDs = set() - - # for i, df in enumerate(df_li): - # if df is None: - # continue - - # if 'generation_num_tree' not in df.columns: - # continue - - # mask = (df['generation_num_tree'].isnull() | - # df["generation_num_tree"].isna()) - - # if mask.any() or df["generation_num_tree"].empty: - # continue - - # df = checked_reset_index_Cell_ID(df) - - # df = filter_cols(df) - # df = reorg_sister_cells_for_import(df) - # self.frames_for_dfs.add(i) - # df_li_new.append(df) - - # df_filter = df.index.isin(added_IDs) - # for root_ID, group in df[df_filter].groupby('root_ID_tree'): - # if root_ID not in families_root_IDs: - # family = list(zip(group.index, group['generation_num_tree'])) - # families.append(family) - # families_root_IDs.append(root_ID) - # else: - # # If the root_ID is already in families, we just update the family with the new cells - # family_index = families_root_IDs.index(root_ID) - # families[family_index].extend(zip(group.index, group['generation_num_tree'])) - - # added_IDs.update(group.index) - - # if df_li_new: - # self.lineage_list = df_li_new - - # This will probably be made obsolete by the gui_mode version def export_df(self, frame_i): """ Export the lineage DataFrame for a specific frame, cleaning up auxiliary columns. @@ -1256,8 +1067,8 @@ def track(self, IoA_thresh_daughter=IoA_thresh_daughter ) pbar.update() - rp = regionprops(segm_video[0]) - prev_IDs = {obj.label for obj in rp} + rp = acdcRegionprops(segm_video[0], precache_centroids=False) + prev_IDs = rp.IDs_set prev_rp = rp continue @@ -1270,8 +1081,8 @@ def track(self, IDs_prev = tracker.IDs_prev assignments = tracker.assignments IDs_curr_untracked = tracker.IDs_curr_untracked - rp = regionprops(tracker.tracked_lab) - curr_IDs = {obj.label for obj in rp} + rp = acdcRegionprops(tracker.tracked_lab) + curr_IDs = rp.IDs_set new_IDs = curr_IDs - prev_IDs if record_lineage or return_tracked_lost_centroids: tree.add_new_frame( @@ -1289,7 +1100,7 @@ def track(self, found = True break if not found: - labels = [obj.label for obj in rp] + labels = rp.IDs printl(mother, mother_ID, IDs_curr_untracked, labels) raise ValueError('Something went wrong with the tracked lost centroids.') @@ -1328,6 +1139,9 @@ def track_frame(self, min_daughter:int = 2, max_daughter:int = 2, unique_ID: NotGUIParam =None, + return_assignments: NotGUIParam =False, + specific_IDs: NotGUIParam =None, + dont_return_tracked_lab: NotGUIParam =False, ): """ Tracks cell division in a single frame. (This is used for real time tracking in the GUI) @@ -1352,14 +1166,32 @@ def track_frame(self, segm_video = [previous_frame_labels, current_frame_labels] tracker = normal_division_tracker(segm_video, IoA_thresh_daughter, min_daughter, max_daughter, IoA_thresh, IoA_thresh_aggressive) - tracker.track_frame(1, IDs=IDs, unique_ID=unique_ID) - tracked_video = tracker.tracked_video + tracker.track_frame( + 1, + IDs=IDs, + unique_ID=unique_ID, + return_assignments=return_assignments, + specific_IDs=specific_IDs, + dont_return_tracked_lab=dont_return_tracked_lab, + ) mother_daughters_pairs = tracker.mother_daughters IDs_prev = tracker.IDs_prev mothers = {IDs_prev[pair[0]] for pair in mother_daughters_pairs} + assignments = tracker.assignments + + if dont_return_tracked_lab: + return assignments + + tracked_lab = tracker.tracked_video[-1] + if not return_assignments: + return tracked_lab - return tracked_video[-1], mothers + add_info = { + 'mothers': mothers, + 'assignments': assignments + } + return tracked_lab, add_info def updateGuiProgressBar(self, signals): """ diff --git a/cellacdc/whitelist.py b/cellacdc/whitelist.py index 0621f4514..c2211b96e 100644 --- a/cellacdc/whitelist.py +++ b/cellacdc/whitelist.py @@ -1,7 +1,7 @@ import os import numpy as np import skimage.measure -from . import printl, myutils +from . import printl, myutils, regionprops import json from typing import Set, List, Tuple import time @@ -222,14 +222,14 @@ def create_new_centroids(self, new_IDs = self.originalLabsIDs[i] - self.originalLabsIDs[i-1] - rp = None if frame_i==i and curr_rp is not None: rp = curr_rp else: - rp = skimage.measure.regionprops(self.originalLabs[i]) + rp = regionprops.acdcRegionprops(self.originalLabs[i], + precache_centroids=False) self.new_centroids.append({ - tuple(map(int, obj.centroid)) for obj in rp if obj.label in new_IDs + tuple(map(int, rp.get_centroid(label))) for label in new_IDs }) @@ -411,7 +411,7 @@ def IDsAccepted(self, printl('Using curr_lab') IDs_curr = {obj.label for obj in skimage.measure.regionprops(lab)} else: - IDs_curr = allData_li[frame_i]['IDs'] + IDs_curr = allData_li[frame_i]['regionprops'].IDs_set if self._debug: printl('Using allData_li') @@ -488,7 +488,7 @@ def makeOriginalLabsAndIDs(self, segm_data: np.ndarray, IDs = set(IDs_curr) elif allData_li is not None: try: - IDs = set(allData_li[i]['IDs']) + IDs = allData_li[i]['regionprops'].IDs_set except KeyError: pass if IDs is None: @@ -746,7 +746,7 @@ def propagateIDs(self, printl('Using index_lab_combo') IDs_curr = {obj.label for obj in skimage.measure.regionprops(lab)} elif curr_rp is not None: - IDs_curr = {obj.label for obj in curr_rp} + IDs_curr = curr_rp.IDs_set if self._debug: printl('Using rp') elif curr_lab is not None: @@ -755,7 +755,7 @@ def propagateIDs(self, printl('Using curr_lab') IDs_curr = {obj.label for obj in skimage.measure.regionprops(lab)} else: - IDs_curr = allData_li[frame_i]['IDs'] + IDs_curr = allData_li[frame_i]['regionprops'].IDs_set if self._debug: printl('Using allData_li') @@ -872,7 +872,7 @@ def propagateIDs(self, if frame_i == i: IDs_curr_loc = IDs_curr else: - IDs_curr_loc = set(allData_li[i]['IDs']) + IDs_curr_loc =allData_li[i]['regionprops'].IDs_set new_whitelist = self.get(i, try_create_new_whitelists).copy() old_whitelist = new_whitelist.copy() @@ -939,12 +939,12 @@ def whitelistTrackOGagainstPreviousFrame_cb(self, signal_slot=None): if not self.whitelistCheckOriginalLabels(): return old_cell_IDs = posData.whitelist.originalLabsIDs[frame_i] - prev_cell_IDs = posData.allData_li[frame_i-1]['IDs'] + prev_cell_IDs = posData.allData_li[frame_i-1]['regionprops'].IDs_set self.whitelistTrackOGCurr(against_prev=True) new_cell_IDs = posData.whitelist.originalLabsIDs[frame_i] new_IDs = new_cell_IDs - old_cell_IDs - new_IDs = new_IDs & set(prev_cell_IDs) + new_IDs = new_IDs & prev_cell_IDs self.whitelistUpdateLab( track_og_curr=False, IDs_to_add=new_IDs, @@ -1066,7 +1066,7 @@ def whitelistViewOGIDs(self, checked:bool): self.store_data(autosave=False) if frame_i > 0: - missing_IDs = set(posData.IDs) - set(posData.allData_li[frame_i-1]['IDs']) + missing_IDs = posData.IDs_set - posData.allData_li[frame_i-1]['regionprops'].IDs_set self.trackManuallyAddedObject(missing_IDs,isNewID=True, wl_update=False) self.setAllTextAnnotations() @@ -1502,7 +1502,7 @@ def whitelistTrackOGCurr(self, frame_i:int=None, ### against what should I track? if lab is not None and not rp: - rp = skimage.measure.regionprops(lab) + rp = regionprops.acdcRegionprops(lab, precache_centroids=False) changed_frame = False if lab is None: @@ -1520,7 +1520,7 @@ def whitelistTrackOGCurr(self, frame_i:int=None, rp = posData.rp lab = posData.lab og_lab = posData.whitelist.originalLabs[frame_i] - og_rp = skimage.measure.regionprops(og_lab) + og_rp = regionprops.acdcRegionprops(og_lab, precache_centroids=False) # lab = lab.copy() denom_overlap_matrix = 'union' if not against_prev else 'area_prev' @@ -1530,7 +1530,6 @@ def whitelistTrackOGCurr(self, frame_i:int=None, denom_overlap_matrix=denom_overlap_matrix, posData = posData, setBrushID_func=self.setBrushID, - IDs=IDs, # assign_unique_new_IDs=False, ) @@ -1583,7 +1582,7 @@ def whitelistTrackCurrOG(self, frame_i:int=None, against_prev:bool=False): else: og_lab = posData.whitelist.originalLabs[frame_i] - og_rp = skimage.measure.regionprops(og_lab) + og_rp = regionprops.acdcRegionprops(og_lab, precache_centroids=False) denom_overlap_matrix = 'union' if not against_prev else 'area_prev' diff --git a/cellacdc/widgets.py b/cellacdc/widgets.py index b22e1cec4..bb247e504 100755 --- a/cellacdc/widgets.py +++ b/cellacdc/widgets.py @@ -71,6 +71,7 @@ from . import _core, core from . import QtScoped from . import prompts +from . import fonts from .acdc_regex import float_regex from .config import PREPROCESS_MAPPER from . import _base_widgets @@ -84,8 +85,7 @@ PROGRESSBAR_HIGHLIGHTEDTEXT_QCOLOR = _palettes.QProgressBarHighlightedTextColor() TEXT_COLOR = _palettes.text_float_rgba() -font = QFont() -font.setPixelSize(12) +font = fonts.font custom_cmaps_filepath = os.path.join(settings_folderpath, 'custom_colormaps.ini') @@ -1075,7 +1075,7 @@ class _ReorderableListModel(QAbstractListModel): def __init__(self, items, parent=None): QAbstractItemModel.__init__(self, parent) - self.nodes = items + self.nodes = list(items) self.lastDroppedItems = [] self.pendingRemoveRowsAfterDrop = False @@ -1286,7 +1286,7 @@ def __init__( self.setStyleSheet(styleSheet) def setItems(self, items): - self._model.nodes = items + self._model.nodes = list(items) def items(self): return self._model.nodes @@ -1464,10 +1464,9 @@ def warnSelectionEmpty(self): def ok_cb(self, checked=False): self.clickedButton = self.sender() - self.cancel = False selectedItems = self.listBox.selectedItems() - self.selectedItemsText = [item.text() for item in selectedItems] - if not self.allowSingleSelection and len(self.selectedItemsText) < 2: + selectedItemsText = [item.text() for item in selectedItems] + if not self.allowSingleSelection and len(selectedItemsText) < 2: msg = myMessageBox(wrapText=False, showCentered=False) txt = html_utils.paragraph( 'You need to select two or more items.

' @@ -1477,9 +1476,12 @@ def ok_cb(self, checked=False): msg.warning(self, 'Select two or more items', txt) return - if not self.allowEmptySelection and not self.selectedItemsText: + if not self.allowEmptySelection and not selectedItemsText: self.warnSelectionEmpty() return + + self.cancel = False + self.selectedItemsText = selectedItemsText self.sigSelectionConfirmed.emit(self.selectedItemsText) self.close() @@ -12083,4 +12085,305 @@ def closeEvent(self, event): if self.screenShotWin is not None: self.screenShotWin.close() - return super().closeEvent(event) \ No newline at end of file + return super().closeEvent(event) + + +class MultiPickListWidget(QWidget): + """Generic list widget with multi-pick (repeated-selection) support. + + Each pickable row shows ``- count +`` controls. Left-clicking adds + one instance; right-clicking or Ctrl+left-click removes one. The same + item can appear multiple times in :attr:`selectionSequence`. + + Parameters + ---------- + items: + Initial list of item labels. + excludedItems: + Labels that are shown in the list but *not* given +/- controls + (e.g. placeholder entries like "Add custom model…"). Click events + on these are silently ignored. + parent: + Optional parent widget. + """ + + sigSelectionChanged = Signal(list) # emits selectionSequence on every change + + def __init__(self, items=None, excludedItems=None, parent=None): + super().__init__(parent) + + self._excludedItems = set(excludedItems or []) + self._itemsMap = {} # label → QListWidgetItem + self._countMap = defaultdict(int) + self._countLabelMap = {} + self.selectionSequence = [] + + layout = QVBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + + self.listBox = listWidget(isMultipleSelection=False) + self.listBox.setSelectionMode(QAbstractItemView.SelectionMode.SingleSelection) + + for label in (items or []): + self._addListItem(label) + + if self._itemsMap: + self.listBox.setCurrentRow(0) + + self.listBox.itemClicked.connect(self._onItemClicked) + self.listBox.setContextMenuPolicy(Qt.ContextMenuPolicy.CustomContextMenu) + self.listBox.customContextMenuRequested.connect(self._onRightClick) + + # self.listBox.setStyleSheet(LISTWIDGET_STYLESHEET) + # self.setStyleSheet(LISTWIDGET_STYLESHEET) + layout.addWidget(self.listBox) + + + @property + def itemsMap(self): + """Dict mapping label → QListWidgetItem for all pickable items.""" + return dict(self._itemsMap) + + def currentItemName(self): + """Return the label of the currently highlighted item, or ``None``.""" + item = self.listBox.currentItem() + return item.text() if item is not None else None + + def addSelection(self, label): + """Add one instance of *label* to the selection.""" + if label not in self._itemsMap: + return + self.selectionSequence.append(label) + self._countMap[label] += 1 + self._updateCountLabel(label) + self.listBox.setCurrentItem(self._itemsMap[label]) + self.sigSelectionChanged.emit(list(self.selectionSequence)) + + def removeSelection(self, label): + """Remove the last instance of *label* from the selection.""" + if self._countMap.get(label, 0) <= 0: + return + for i in range(len(self.selectionSequence) - 1, -1, -1): + if self.selectionSequence[i] == label: + self.selectionSequence.pop(i) + break + self._countMap[label] = max(0, self._countMap[label] - 1) + self._updateCountLabel(label) + self.sigSelectionChanged.emit(list(self.selectionSequence)) + + def resetSelection(self): + """Clear all selections and reset all counters to zero.""" + self.selectionSequence = [] + self._countMap = defaultdict(int) + for label in self._countLabelMap: + self._updateCountLabel(label) + self.sigSelectionChanged.emit([]) + + def setSelectionFromList(self, labels): + """Set the selection to *labels* (duplicates supported).""" + self.resetSelection() + for label in labels: + self.addSelection(label) + + def registerItem(self, label, insertBeforeLabel=None): + """Dynamically add a new pickable item. + + Parameters + ---------- + label: + Text for the new item. + insertBeforeLabel: + If given, insert the new item immediately before this label. + If not found or not given, the item is appended. + + Returns the created ``QListWidgetItem``. + """ + if label in self._itemsMap: + return self._itemsMap[label] + + item = QListWidgetItem(label) + + if insertBeforeLabel is not None: + target = self._itemsMap.get(insertBeforeLabel) + if target is None: + for row in range(self.listBox.count()): + row_item = self.listBox.item(row) + if row_item is not None and row_item.text() == insertBeforeLabel: + target = row_item + break + if target is not None: + row = self.listBox.row(target) + self.listBox.insertItem(row, item) + else: + self.listBox.addItem(item) + else: + self.listBox.addItem(item) + + self._itemsMap[label] = item + self._addCounterWidget(label, item) + return item + + def _addListItem(self, label): + """Create a QListWidgetItem and, if pickable, attach a counter widget.""" + item = QListWidgetItem(label) + self.listBox.addItem(item) + if label not in self._excludedItems: + self._itemsMap[label] = item + self._addCounterWidget(label, item) + + def _addCounterWidget(self, label, item): + rowWidget = QWidget() + rowLayout = QHBoxLayout(rowWidget) + rowLayout.setContentsMargins(4, 0, 4, 0) + rowLayout.setSpacing(6) + + nameLabelPlaceholder = QSpacerItem(2, 0) + minusBtn = QPushButton('-') + plusBtn = QPushButton('+') + countLabel = QLabel(str(self._countMap.get(label, 0))) + + minusBtn.setFixedWidth(24) + plusBtn.setFixedWidth(24) + countLabel.setMinimumWidth(20) + countLabel.setAlignment(Qt.AlignCenter) + + minusBtn.clicked.connect(lambda _, lbl=label: self.removeSelection(lbl)) + plusBtn.clicked.connect(lambda _, lbl=label: self.addSelection(lbl)) + + rowLayout.addItem(nameLabelPlaceholder) + rowLayout.addStretch(1) + rowLayout.addWidget(minusBtn) + rowLayout.addWidget(countLabel) + rowLayout.addWidget(plusBtn) + + self._countLabelMap[label] = countLabel + self._updateCountLabel(label) + self.listBox.setItemWidget(item, rowWidget) + + def _updateCountLabel(self, label): + lbl = self._countLabelMap.get(label) + if lbl is not None: + count = self._countMap.get(label, 0) + lbl.setText(str(count)) + if count <= 0: + lbl.setStyleSheet('color: gray;') + else: + lbl.setStyleSheet('') + + def _onItemClicked(self, item): + label = item.text() + if label in self._excludedItems: + return + modifiers = QApplication.keyboardModifiers() + if modifiers & Qt.ControlModifier: + self.removeSelection(label) + else: + self.addSelection(label) + + def _onRightClick(self, pos): + item = self.listBox.itemAt(pos) + if item is None: + return + label = item.text() + if label in self._excludedItems: + return + self.removeSelection(label) + + +class ModelSelectionWidget(QWidget): + """List widget for selecting segmentation models. + + Thin wrapper around :class:`MultiPickListWidget` that populates the list + with the installed models and adds a special "Add custom model…" entry. + + ``sigSelectionChanged`` and ``selectionSequence`` are proxied from the + underlying :class:`MultiPickListWidget`. + """ + + _ADD_CUSTOM = 'Add custom model...' + + sigSelectionChanged = Signal(list) + + def __init__(self, parent=None, customFirst='', allowMultiSelection=False): + super().__init__(parent) + + self.allowMultiSelection = allowMultiSelection + + models = myutils.get_list_of_models() + if customFirst: + try: + models.insert(0, models.pop(models.index(customFirst))) + except ValueError: + pass + + items = models + [self._ADD_CUSTOM] + + layout = QVBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + + if allowMultiSelection: + items = models + self._picker = MultiPickListWidget( + items=items, + excludedItems=[self._ADD_CUSTOM], + parent=self, + ) + self._picker.listBox.setFont(font) + self._picker.sigSelectionChanged.connect(self.sigSelectionChanged) + self.listBox = self._picker.listBox + layout.addWidget(self._picker) + else: + self.listBox = listWidget(isMultipleSelection=False) + self.listBox.setFont(font) + self.listBox.addItems(models) + add_item = QListWidgetItem(self._ADD_CUSTOM) + add_item.setFont(fonts.italicFont) + self.listBox.addItem(add_item) + self.listBox.setSelectionMode(QAbstractItemView.SelectionMode.SingleSelection) + self.listBox.setCurrentRow(0) + self._picker = None + layout.addWidget(self.listBox) + + # ------------------------------------------------------------------ + # Proxy helpers (multi-selection mode only) + # ------------------------------------------------------------------ + + @property + def selectionSequence(self): + return self._picker.selectionSequence if self._picker is not None else [] + + @property + def modelItemsMap(self): + return self._picker.itemsMap if self._picker is not None else {} + + def currentModelName(self): + if self._picker is not None: + return self._picker.currentItemName() + item = self.listBox.currentItem() + return item.text() if item is not None else None + + def addModelSelection(self, name): + if self._picker is not None: + self._picker.addSelection(name) + + def removeModelSelection(self, name): + if self._picker is not None: + self._picker.removeSelection(name) + + def resetSelectionSequence(self): + if self._picker is not None: + self._picker.resetSelection() + + def setSelectionFromList(self, models): + if self._picker is not None: + self._picker.setSelectionFromList(models) + + def registerCustomModel(self, model_name): + """Add a newly registered custom model and return its item.""" + if self._picker is not None: + return self._picker.registerItem( + model_name, insertBeforeLabel=self._ADD_CUSTOM + ) + item = QListWidgetItem(model_name) + self.listBox.insertItem(self.listBox.count() - 1, item) + return item diff --git a/cellacdc/workers.py b/cellacdc/workers.py index 25feb6f72..cb21ae732 100755 --- a/cellacdc/workers.py +++ b/cellacdc/workers.py @@ -43,7 +43,7 @@ from . import cli from .utils import resize from . import segm_utils - +from . import regionprops DEBUG = False def worker_exception_handler(func): @@ -178,7 +178,7 @@ def run(self): for frame_i, data_dict in enumerate(self.posData.allData_li): lab = data_dict['labels'] rp = data_dict['regionprops'] - IDs = data_dict['IDs'] + IDs = data_dict['regionprops'].IDs if lab is None: lab = self.posData.segm_data[frame_i] rp = skimage.measure.regionprops(lab) @@ -198,14 +198,10 @@ def run(self): class SegForLostIDsWorker(QObject): sigAskInit = Signal() - sigAskInstallModel = Signal(str) sigshowImageDebug = Signal(object) sigStoreData = Signal(bool) sigUpdateRP = Signal(bool, bool) - # sigGetData = Signal() - # sigGet2Dlab = Signal() - # sigGetTrackedLostIDs = Signal() - # sigGetBrushID = Signal() + sigGetSegForLostIDsInputImg = Signal(str) sigSegForLostIDsWorkerAskInstallGPU = Signal(str, bool) sigTrackManuallyAddedObject = Signal(object, object, bool, bool) @@ -217,6 +213,7 @@ def __init__(self, guiWin, mutex, waitCond, debug=False): self.mutex = mutex self.waitCond = waitCond self._debug = debug + self.inputImgForSegForLostIDs = None def emitSigAskInit(self): self.mutex.lock() @@ -224,35 +221,26 @@ def emitSigAskInit(self): self.waitCond.wait(self.mutex) self.mutex.unlock() - def emitSigShowImageDebug(self, img): - # self.mutex.lock() - self.sigshowImageDebug.emit(img) - # self.waitCond.wait(self.mutex) - # self.mutex.unlock() - def emitSigStoreData(self, autosave): self.mutex.lock() self.sigStoreData.emit(autosave) self.waitCond.wait(self.mutex) self.mutex.unlock() - def emitSigUpdateRP(self, wl_track_og_curr, wl_update): + def emitSigUpdateRP(self, wl_update, wl_track_og_curr): self.mutex.lock() - self.sigUpdateRP.emit(wl_track_og_curr, wl_update) + self.sigUpdateRP.emit(wl_update, wl_track_og_curr) self.waitCond.wait(self.mutex) self.mutex.unlock() - # def emitSigGetData(self): - # self.mutex.lock() - # self.sigGetData.emit() - # self.waitCond.wait(self.mutex) - # self.mutex.unlock() - - def emitSigAskInstallModel(self, model_name): + def emitGetSegForLostIDsInputImg(self, image_channel_name): self.mutex.lock() - self.sigAskInstallModel.emit(model_name) + self.sigGetSegForLostIDsInputImg.emit(image_channel_name) self.waitCond.wait(self.mutex) + img = self.inputImgForSegForLostIDs + self.inputImgForSegForLostIDs = None self.mutex.unlock() + return img def emitSigAskInstallGPU(self, base_model_name, use_gpu): self.mutex.lock() @@ -261,24 +249,6 @@ def emitSigAskInstallGPU(self, base_model_name, use_gpu): self.waitCond.wait(self.mutex) self.mutex.unlock() - # def emitGet2Dlab(self): - # self.mutex.lock() - # self.sigGet2Dlab.emit() - # self.waitCond.wait(self.mutex) - # self.mutex.unlock() - - # def emitGetTrackedLostIDs(self): - # self.mutex.lock() - # self.sigGetTrackedLostIDs.emit() - # self.waitCond.wait(self.mutex) - # self.mutex.unlock() - - # def emitGetBrushID(self): - # self.mutex.lock() - # self.sigGetBrushID.emit() - # self.waitCond.wait(self.mutex) - # self.mutex.unlock() - def emitTrackManuallyAddedObject(self, IDs, isLost, wl_update, wl_track_og_curr): self.mutex.lock() self.sigTrackManuallyAddedObject.emit(IDs, isLost, wl_update, wl_track_og_curr) @@ -298,36 +268,69 @@ def run(self): return self.logger.info('Segmentation for lost IDs started.') - model_name = 'local_seg' - base_model_name = self.guiWin.SegForLostIDsSettings['base_model_name'] - idx = self.guiWin.modelNames.index(model_name) - acdcSegment = self.guiWin.acdcSegment_li[idx] - - init_kwargs = self.guiWin.SegForLostIDsSettings['win'].init_kwargs - - use_gpu = init_kwargs.get('device_type', 'cpu') != 'cpu' - use_gpu = use_gpu or init_kwargs.get('use_gpu', False) - - self.emitSigAskInstallGPU(base_model_name, use_gpu) + + model_settings = self.guiWin.SegForLostIDsSettings['models_settings'] + + n_models = len(model_settings) + total_steps = 2 * n_models + self.signals.initProgressBar.emit(total_steps) - if not self.gpu_go: - self.signals.finished.emit(self) - return + assigned_IDs = [] + missing_IDs_global = set() + original_lab = posData.lab.copy() + IDs_bboxs_list = [] + bboxs_list = [] + new_labs = [] - if not self.dont_force_cpu: - if 'device' in init_kwargs: - init_kwargs['device'] = 'cpu' - if 'use_gpu' in init_kwargs: - init_kwargs['use_gpu'] = False + prev_lab = self.guiWin.get_2Dlab(posData.allData_li[frame_i-1]['labels']) + prev_IDs = posData.allData_li[frame_i-1]['regionprops'].IDs_set + + # iteratively go through models, keeping found labs and only feeding remaining lost IDs to the next model + for model_idx, model_settings_i in enumerate(model_settings): + base_model_name = model_settings_i['base_model_name'] + init_kwargs_new = dict(model_settings_i['init_kwargs_new']) + image_channel_name = init_kwargs_new.pop( + 'image_channel_name', 'Displayed image' + ) + args_new = model_settings_i['args_new'] + init_kwargs = model_settings_i.get('init_kwargs', {}) + model_kwargs = model_settings_i.get('model_kwargs', {}) + preproc_recipe = model_settings_i.get('preproc_recipe', None) + applyPostProcessing = model_settings_i.get('applyPostProcessing', False) + standardPostProcessKwargs = model_settings_i.get('standardPostProcessKwargs', {}) + customPostProcessFeatures = model_settings_i.get('customPostProcessFeatures', None) + customPostProcessGroupedFeatures = model_settings_i.get('customPostProcessGroupedFeatures', None) + + # Fall back to reading from the live win object when available + win = model_settings_i.get('win') + if win is not None: + init_kwargs = win.init_kwargs + model_kwargs = win.model_kwargs + preproc_recipe = win.preproc_recipe + applyPostProcessing = win.applyPostProcessing + standardPostProcessKwargs = win.standardPostProcessKwargs + customPostProcessFeatures = win.customPostProcessFeatures + customPostProcessGroupedFeatures = win.customPostProcessGroupedFeatures + + use_gpu = init_kwargs.get('device_type', 'cpu').lower() != 'cpu' + use_gpu = use_gpu or init_kwargs.get('use_gpu', False) + + self.emitSigAskInstallGPU(base_model_name, use_gpu) + + if not self.gpu_go: + self.signals.finished.emit(self) + return + + if not self.dont_force_cpu: + if 'device' in init_kwargs: + init_kwargs_new = dict(init_kwargs_new, device='cpu') + if 'use_gpu' in init_kwargs: + init_kwargs_new = dict(init_kwargs_new, use_gpu=False) - if acdcSegment is None or base_model_name != self.guiWin.local_seg_base_model_name: try: self.logger.info(f'Importing {base_model_name}...') - self.emitSigAskInstallModel(base_model_name) acdcSegment = myutils.import_segment_module(base_model_name) - self.guiWin.acdcSegment_li[idx] = acdcSegment - self.guiWin.local_seg_base_model_name = base_model_name - except (IndexError, ImportError, KeyError) as e: + except (IndexError, ImportError, KeyError): self.logger.warning( f'Cannot import {base_model_name} model. ' 'Please install it first.' @@ -339,134 +342,165 @@ def run(self): self.signals.finished.emit(self) return - win = self.guiWin.SegForLostIDsSettings['win'] - init_kwargs_new = self.guiWin.SegForLostIDsSettings['init_kwargs_new'] - args_new = self.guiWin.SegForLostIDsSettings['args_new'] - - model = myutils.init_segm_model(acdcSegment, posData, init_kwargs_new) - if model is None: - self.logger.info('Segmentation model was not initialized correctly!') - self.signals.critical.emit( - (self, 'Segmentation model was not initialized correctly!') - ) - self.signals.finished.emit(self) - return - if self._debug: - try: - model.setupLogger(self.guiwin.logger) - except Exception as e: - pass - - assigned_IDs = [] - missing_IDs_global = set() - original_lab = posData.lab.copy() - IDs_bboxs_list = [] - bboxs_list = [] + model = myutils.init_segm_model(acdcSegment, posData, init_kwargs_new) + if model is None: + self.logger.info('Segmentation model was not initialized correctly!') + self.signals.critical.emit( + (self, 'Segmentation model was not initialized correctly!') + ) + self.signals.finished.emit(self) + return + if self._debug: + try: + model.setupLogger(self.guiWin.logger) + except Exception: + pass - curr_img = self.guiWin.getDisplayedImg1() - prev_lab = self.guiWin.get_2Dlab(posData.allData_li[frame_i-1]['labels']) - prev_IDs = set(posData.allData_li[frame_i-1]['IDs']) + curr_img = self.emitGetSegForLostIDsInputImg(image_channel_name) + if curr_img is None: + self.signals.critical.emit( + (self, 'Could not get input image for SegForLostIDsWorker') + ) + self.signals.finished.emit(self) + return - # should probably not paly so much with posData.lab, instead handle stuff myself - self.signals.initProgressBar.emit(2 * args_new['max_iterations']) - new_labs = np.zeros([args_new['max_iterations'], *posData.lab.shape], dtype=np.uint32) - for i in range(args_new['max_iterations']): curr_lab = self.guiWin.get_2Dlab(posData.lab) tracked_lost_IDs = self.guiWin.getTrackedLostIDs() new_unique_ID = self.guiWin.setBrushID(useCurrentLab=True, return_val=True) - missing_IDs = prev_IDs - set(posData.IDs) - set(tracked_lost_IDs) + missing_IDs = prev_IDs - posData.rp.IDs_set - set(tracked_lost_IDs) missing_IDs_global.update(missing_IDs) assigned_IDs_prev = assigned_IDs.copy() out = segm_utils.single_cell_seg( - model, prev_lab, curr_lab, curr_img, + model, prev_lab, curr_lab, curr_img, missing_IDs, new_unique_ID, - win, posData, + posData, distance_filler_growth=args_new['distance_filler_growth'], overlap_threshold=args_new['overlap_threshold'], padding=args_new['padding'], + model_kwargs=model_kwargs, + preproc_recipe=preproc_recipe, + applyPostProcessing=applyPostProcessing, + standardPostProcessKwargs=standardPostProcessKwargs, + customPostProcessFeatures=customPostProcessFeatures, + customPostProcessGroupedFeatures=customPostProcessGroupedFeatures, + debug=self._debug ) - new_lab, assigned_IDs, IDs_bboxs, bboxs = out - + if self._debug: + new_lab, assigned_IDs, IDs_bboxs, bboxs, imgs_to_show = out + else: + new_lab, assigned_IDs, IDs_bboxs, bboxs = out + IDs_bboxs_list.append(IDs_bboxs) bboxs_list.append(bboxs) posData.lab = new_lab self.emitSigUpdateRP(wl_update=True, wl_track_og_curr=False) newly_assigned_IDs = set(assigned_IDs) - set(assigned_IDs_prev) self.emitTrackManuallyAddedObject(newly_assigned_IDs, True, False, False) - new_labs[i] = posData.lab.copy() + new_labs.append(posData.lab.copy()) self.signals.progressBar.emit(1) - if self._debug: - originals = [] - models = [] - - posData.lab = original_lab.copy() - - global_area_mean = np.mean([obj.area for obj in posData.rp]) - for IDs_bboxs, bboxs in zip(IDs_bboxs_list, bboxs_list): - model_lab = new_labs[i] if self._debug: - originals.append(original_lab.copy()) - models.append(posData.lab.copy()) - - for IDs, bbox in zip(IDs_bboxs, bboxs): + print(f'Model {model_idx}:') + print('Displaying curr_img and curr_lab:') + display_info = { + 'title': f'Model {model_idx}, input image and lab', + 'images': [curr_img, curr_lab], + 'img_titles': ['curr_img', 'curr_lab'] + } + self.sigshowImageDebug.emit(display_info) + print('box_curr_img, box_curr_lab, box_curr_lab_other_IDs_grown, box_curr_img (after filling), box_model_lab') + for i, imgs in imgs_to_show.items(): + display_info = { + 'title': f'Model {model_idx}, bbox {i}', + 'images': imgs, + 'img_titles': [ + 'box_curr_img', 'box_curr_lab', 'box_curr_lab_other_IDs_grown', + 'box_curr_img (after filling)', 'box_model_lab' + ] + } + self.sigshowImageDebug.emit(display_info) + global_areas = [obj.area for obj in posData.rp] + global_area_mean = np.mean(global_areas) if len(global_areas) > 0 else None + for i, (IDs_bboxs, bboxs) in enumerate(zip(IDs_bboxs_list, bboxs_list)): + args_new = model_settings[i]['args_new'] + model_lab = new_labs[i] + + for j, (IDs, bbox) in enumerate(zip(IDs_bboxs, bboxs)): + box_x_min, box_x_max, box_y_min, box_y_max = bbox + + model_bbox_lab = model_lab[box_x_min:box_x_max, box_y_min:box_y_max] + model_bbox_lab = skimage.segmentation.clear_border(model_bbox_lab, buffer_size=1) + model_lab_rp = regionprops.acdcRegionprops(model_bbox_lab, precache_centroids=False) + + if self._debug: + IDs_filtered_border = [ + ID for ID in IDs + if ID not in model_lab_rp.IDs_set + ] - box_x_min, box_x_max, box_y_min, box_y_max = bbox original_bbox_lab = original_lab[box_x_min:box_x_max, box_y_min:box_y_max] original_bbox_lab_cleared_borders = skimage.segmentation.clear_border(original_bbox_lab) - box_model_lab = model_lab[box_x_min:box_x_max, box_y_min:box_y_max] + original_lab_rp = regionprops.acdcRegionprops(original_bbox_lab_cleared_borders, precache_centroids=False) - # original_bbox_lab[np.isin(original_bbox_lab, IDs)] = 0 should be a given. If not seg for lost IDs this recommended - - box_model_lab = skimage.segmentation.clear_border(box_model_lab, buffer_size=1) - - rp_model_lab = skimage.measure.regionprops(box_model_lab) - rp_original_lab = skimage.measure.regionprops(original_bbox_lab) - rp_original_lab_cleared = skimage.measure.regionprops(original_bbox_lab_cleared_borders) - - original_IDs = [obj.label for obj in rp_original_lab] - areas = [obj.area for obj in rp_original_lab_cleared] + areas = [obj.area for obj in original_lab_rp] if len(areas) > 0: area_mean = np.mean(areas) - else: + elif global_area_mean is not None: area_mean = global_area_mean + else: + model_areas = [obj.area for obj in model_lab_rp] + area_mean = np.mean(model_areas) if len(model_areas) > 0 else None + + skip_size_filter = area_mean is None + if not skip_size_filter: + min_area = (1 - args_new['size_perc_diff']) * area_mean + max_area = (1 + args_new['size_perc_diff']) * area_mean + if args_new['allow_only_tracked_cells']: - filtered_IDs = [obj.label for obj in rp_model_lab - if obj.area > (1 - args_new['size_perc_diff']) * area_mean - and obj.area < (1 + args_new['size_perc_diff']) * area_mean - and obj.label not in original_IDs - and obj.label in missing_IDs_global] + filtered_IDs = [ + obj.label for obj in model_lab_rp + if (skip_size_filter or (obj.area > min_area and obj.area < max_area)) + and obj.label in prev_IDs # only keep objects that have ID already in the previous frame + ] + else: - filtered_IDs = [obj.label for obj in rp_model_lab - if obj.area > (1 - args_new['size_perc_diff']) * area_mean - and obj.area < (1 + args_new['size_perc_diff']) * area_mean - and obj.label not in original_IDs] - - if self._debug or DEBUG: - filtered_sizes = [(obj.label, obj.area) for obj in rp_model_lab if obj.label in filtered_IDs] - self.logger.info(f"Filtered sizes: {filtered_sizes}") - for label in filtered_IDs: - original_bbox_lab[box_model_lab == label] = label # here the stuff should be tracked, so we keep the ID! - - # original_lab[box_x_min:box_x_max, box_y_min:box_y_max] = original_bbox_lab - - self.signals.progressBar.emit(1) - - posData.lab = original_lab + filtered_IDs = [ + obj.label for obj in model_lab_rp + if (skip_size_filter or (obj.area > min_area and obj.area < max_area)) + ] + + if self._debug: + if args_new['allow_only_tracked_cells']: + IDs_filtered_for_size = [ + obj.label for obj in model_lab_rp + if (skip_size_filter or (obj.area > min_area and obj.area < max_area)) + ] + else: + IDs_filtered_for_size = [] + IDs_filtered_for_tracking = [ + obj.label for obj in model_lab_rp + if obj.label in prev_IDs + ] + print(f'Model {i}, bbox {j}:') + print(f' Start: {[obj.label for obj in model_lab_rp]}') + print(f' Size: {IDs_filtered_for_size}') + print(f' Tracking: {IDs_filtered_for_tracking}') + print(f' Border: {IDs_filtered_border}') + + original_bbox_lab[ + np.isin(model_bbox_lab, filtered_IDs) + ] = model_bbox_lab[np.isin(model_bbox_lab, filtered_IDs)] - # if self._debug: - # originals = np.concatenate(originals, axis=0) - # models = np.concatenate(models, axis=0) - # self.emitSigShowImageDebug(originals) - # self.emitSigShowImageDebug(models) + self.signals.progressBar.emit(1) - self.emitSigUpdateRP(wl_track_og_curr=True, wl_update=True) + posData.lab = original_lab + self.emitSigUpdateRP(wl_update=True, wl_track_og_curr=False) self.emitSigStoreData(autosave=True) self.logger.info('Segmentation for lost IDs done.') - + self.signals.finished.emit(self) class AlignDataWorker(QObject): @@ -5185,7 +5219,7 @@ def check(self, posData): # There are no annotations at frame_i --> stop break - IDs = data_dict['IDs'] + IDs = data_dict['regionprops'].IDs checker = core.CcaIntegrityChecker(cca_df, lab, IDs) for checkpoint in checkpoints: @@ -6220,7 +6254,7 @@ def saveAcdcDf(self, posData: load.loadData, end_i): last_cca_frame_i=self.mainWin.save_cca_until_frame_i ) - def saveSegmData(self, posData, end_i, saved_segm_data): + def saveSegmData(self, posData: load.loadData, end_i, saved_segm_data): self.progress.emit(f'Saving segmentation data for {posData.relPath}...') @@ -6263,6 +6297,13 @@ def saveSegmData(self, posData, end_i, saved_segm_data): io.savez_compressed( posData.segm_npz_path, np.squeeze(saved_segm_data) ) + + # save information about the segmention + posData.updateSegmMetadata(all=True) + posData.saveSegmMetadataIni() + + # save rp info about segm + self.progress.emit(f'Saving additional data for {posData.relPath}...') posData.segm_data = saved_segm_data # Allow single 2D/3D image if posData.SizeT == 1: diff --git a/precompile_functions.py b/precompile_functions.py new file mode 100644 index 000000000..ab9abd49b --- /dev/null +++ b/precompile_functions.py @@ -0,0 +1,28 @@ +# only needed for cython extensions, not needed to run normally +import sys +from setuptools import setup, Extension +from Cython.Build import cythonize +import numpy as np + +setup( + ext_modules=cythonize( + Extension( + "cellacdc.precompiled.precompiled_functions", + sources=["cellacdc/precompiled_functions.pyx"], + include_dirs=[np.get_include()], + ), + annotate=True, + build_dir="build/cython", # .c and .html files go here + ) +) +# # move compiled binary to precompiled/ +# import shutil +# import os + +# src_dir = "cellacdc" +# for filename in os.listdir(src_dir): +# if filename.startswith("precompiled_functions") and (filename.endswith(".so") or filename.endswith(".pyd")): +# target_path = os.path.join("cellacdc", "precompiled", filename) +# shutil.move(os.path.join(src_dir, filename), target_path) +# print(f"Moved {filename} to {target_path}") + diff --git a/pyproject.toml b/pyproject.toml index e96d7a741..14799461b 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,9 @@ requires = [ "setuptools>=64", "wheel", - "setuptools_scm[toml]>=8" + "setuptools_scm[toml]>=8", + "cython", + "numpy", ] build-backend = "setuptools.build_meta" diff --git a/tests/test_cellacdc_tracker_specific_ids.py b/tests/test_cellacdc_tracker_specific_ids.py new file mode 100644 index 000000000..130e1f8bb --- /dev/null +++ b/tests/test_cellacdc_tracker_specific_ids.py @@ -0,0 +1,160 @@ +import numpy as np +from skimage.measure import regionprops + +from cellacdc.trackers.CellACDC import CellACDC_tracker +from cellacdc.trackers.CellACDC_2steps.CellACDC_2steps_tracker import tracker as TwoStepsTracker +from cellacdc.trackers.CellACDC_normal_division.CellACDC_normal_division_tracker import tracker as NormalDivisionTracker + + +def test_track_frame_specific_ids_only_tracks_requested_current_ids(): + prev_lab = np.array( + [ + [1, 1, 0, 5, 5], + [1, 1, 0, 5, 5], + ], + dtype=np.uint16, + ) + lab = np.array( + [ + [7, 7, 0, 5, 5], + [7, 7, 0, 5, 5], + ], + dtype=np.uint16, + ) + + tracked_lab, assignments = CellACDC_tracker.track_frame( + prev_lab, + regionprops(prev_lab), + lab, + regionprops(lab), + IDs_curr_untracked=[7, 5], + unique_ID=10, + assign_unique_new_IDs=True, + return_assignments=True, + specific_IDs=[5], + ) + + np.testing.assert_array_equal(tracked_lab, lab) + assert assignments['assignments'] == {} + + +def test_track_frame_specific_ids_skips_merging_with_unrelated_current_labels(): + prev_lab = np.array( + [ + [5, 5, 0, 0], + [5, 5, 0, 0], + ], + dtype=np.uint16, + ) + lab = np.array( + [ + [7, 7, 0, 5], + [7, 7, 0, 5], + ], + dtype=np.uint16, + ) + + tracked_lab, add_info = CellACDC_tracker.track_frame( + prev_lab, + regionprops(prev_lab), + lab, + regionprops(lab), + IDs_curr_untracked=[7, 5], + unique_ID=10, + assign_unique_new_IDs=True, + return_assignments=True, + specific_IDs=[7], + ) + + expected = np.array( + [ + [10, 10, 0, 5], + [10, 10, 0, 5], + ], + dtype=np.uint16, + ) + + np.testing.assert_array_equal(tracked_lab, expected) + assert add_info['assignments'] == {7: 10} + + +def test_two_steps_specific_ids_can_match_selected_new_object_to_lost_previous_id(): + prev_lab = np.array( + [ + [5, 5, 0, 0], + [5, 5, 0, 0], + ], + dtype=np.uint16, + ) + lab = np.array( + [ + [7, 7, 0, 0], + [7, 7, 0, 0], + ], + dtype=np.uint16, + ) + + tracked_lab, add_info = TwoStepsTracker( + annotate_objects_tracked_second_step=False + ).track_frame( + prev_lab, + lab, + overlap_threshold=0.4, + lost_IDs_search_range=10, + unique_ID=10, + return_assignments=True, + specific_IDs=[7], + ) + + expected = np.array( + [ + [5, 5, 0, 0], + [5, 5, 0, 0], + ], + dtype=np.uint16, + ) + + np.testing.assert_array_equal(tracked_lab, expected) + assert add_info['assignments'] == {7: 5} + + +def test_normal_division_specific_ids_preserve_division_context(): + prev_lab = np.array( + [ + [5, 5, 5, 5], + [5, 5, 5, 5], + ], + dtype=np.uint16, + ) + lab = np.array( + [ + [7, 7, 8, 8], + [7, 7, 8, 8], + ], + dtype=np.uint16, + ) + + tracked_lab, add_info = NormalDivisionTracker().track_frame( + prev_lab, + lab, + IoA_thresh=0.8, + IoA_thresh_daughter=0.25, + IoA_thresh_aggressive=0.5, + min_daughter=2, + max_daughter=2, + unique_ID=20, + return_assignments=True, + specific_IDs=[7], + ) + + expected = np.array( + [ + [20, 20, 8, 8], + [20, 20, 8, 8], + ], + dtype=np.uint16, + ) + + np.testing.assert_array_equal(tracked_lab, expected) + assert add_info['mothers'] == {5} + assert add_info['assignments'] == {7: 20} \ No newline at end of file diff --git a/tests/test_precompiled_functions.py b/tests/test_precompiled_functions.py new file mode 100644 index 000000000..0515f6ce4 --- /dev/null +++ b/tests/test_precompiled_functions.py @@ -0,0 +1,520 @@ +"""Tests for Cython functions in cellacdc/precompiled_functions.pyx. + +Each Cython function is validated against a pure-Python / skimage reference +implementation on realistic synthetic label images built from filled discs +(2-D) and filled spheres (3-D). + +Run with: + pytest tests/test_precompiled_functions.py -v +""" + +import pytest +import numpy as np +from skimage.draw import disk, ellipsoid +from skimage.measure import regionprops + +# --------------------------------------------------------------------------- +# Skip the whole module when the Cython extension is not compiled yet +# --------------------------------------------------------------------------- +pytest.importorskip( + "cellacdc.precompiled.precompiled_functions", + reason="Cython extension not compiled; run: python precompile_functions.py build_ext --inplace", +) +from cellacdc.precompiled.precompiled_functions import ( + find_all_objects_2D, + find_all_objects_3D, + most_common_projection_3D, + object_projections_and_size_3D, + object_projection_and_size_3D, + calc_IoA_matrix_2D, + calc_IoA_matrix_3D, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_disc_label_image(shape, specs): + """Build a 2-D uint32 label image from a list of (label, cy, cx, radius).""" + img = np.zeros(shape, dtype=np.uint32) + for label, cy, cx, r in specs: + rr, cc = disk((cy, cx), r, shape=shape) + img[rr, cc] = label + return img + + +def _make_sphere_label_image(shape, specs): + """Build a 3-D uint32 label image from (label, cz, cy, cx, rz, ry, rx).""" + img = np.zeros(shape, dtype=np.uint32) + for label, cz, cy, cx, rz, ry, rx in specs: + sph = ellipsoid(rz, ry, rx) + sz, sy, sx = sph.shape + z0 = cz - sz // 2 + y0 = cy - sy // 2 + x0 = cx - sx // 2 + for dz in range(sz): + for dy in range(sy): + for dx in range(sx): + if not sph[dz, dy, dx]: + continue + zi, yi, xi = z0 + dz, y0 + dy, x0 + dx + if 0 <= zi < shape[0] and 0 <= yi < shape[1] and 0 <= xi < shape[2]: + img[zi, yi, xi] = label + return img + + +def _reference_bboxes_2D(label_img): + """skimage regionprops bounding boxes as {label: (r0, r1, c0, c1)}.""" + result = {} + for obj in regionprops(label_img.astype(np.int32)): + r0, c0, r1, c1 = obj.bbox + result[obj.label] = (r0, r1, c0, c1) + return result + + +def _reference_bboxes_3D(label_img): + """skimage regionprops bounding boxes as {label: (z0, z1, r0, r1, c0, c1)}.""" + result = {} + for obj in regionprops(label_img.astype(np.int32)): + z0, r0, c0, z1, r1, c1 = obj.bbox + result[obj.label] = (z0, z1, r0, r1, c0, c1) + return result + + +def _python_ioa_matrix(lab, prev_lab, rp, prev_rp, use_union): + """Pure-Python IoA matrix (the original fallback in CellACDC_tracker).""" + IDs_curr = [obj.label for obj in rp] + IDs_prev = [obj.label for obj in prev_rp] + IoA = np.zeros((len(IDs_curr), len(IDs_prev)), dtype=np.float64) + rp_mapper = {obj.label: obj for obj in rp} + idx_curr = {ID: i for i, ID in enumerate(IDs_curr)} + for j, obj_prev in enumerate(prev_rp): + if use_union: + pass # denom computed per overlap + else: + denom_val = obj_prev.area + intersect_IDs, intersects = np.unique( + lab[obj_prev.slice][obj_prev.image], return_counts=True + ) + for intersect_ID, I in zip(intersect_IDs, intersects): + if intersect_ID == 0 or I == 0: + continue + if use_union: + if intersect_ID not in rp_mapper: + continue + obj_curr = rp_mapper[intersect_ID] + denom_val = obj_prev.area + obj_curr.area - I + if denom_val == 0: + continue + idx = idx_curr.get(intersect_ID) + if idx is None: + continue + IoA[idx, j] = I / denom_val + return IoA, IDs_curr, IDs_prev + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +DISC_SPECS = [ + # (label, cy, cx, radius) + (1, 30, 30, 20), # large disc, centre-left + (2, 30, 90, 15), # medium disc, centre-right + (3, 80, 60, 10), # smaller disc, bottom + (4, 80, 110, 12), # slightly overlapping with disc 3 at boundary +] + +DISC_SPECS_SHIFTED = [ + # same cells shifted a few pixels to simulate motion + (1, 32, 32, 20), + (2, 28, 92, 15), + (3, 82, 62, 10), + (4, 78, 112, 12), +] + +SPHERE_SPECS = [ + # (label, cz, cy, cx, rz, ry, rx) + (1, 10, 20, 20, 5, 8, 8), + (2, 10, 20, 50, 4, 6, 6), + (3, 10, 45, 35, 3, 5, 5), +] + +SPHERE_SPECS_SHIFTED = [ + (1, 11, 21, 21, 5, 8, 8), + (2, 9, 19, 52, 4, 6, 6), + (3, 10, 46, 34, 3, 5, 5), +] + + +@pytest.fixture(scope="module") +def label_2d(): + return _make_disc_label_image((128, 128), DISC_SPECS) + + +@pytest.fixture(scope="module") +def label_2d_shifted(): + return _make_disc_label_image((128, 128), DISC_SPECS_SHIFTED) + + +@pytest.fixture(scope="module") +def label_3d(): + return _make_sphere_label_image((20, 64, 64), SPHERE_SPECS) + + +@pytest.fixture(scope="module") +def label_3d_shifted(): + return _make_sphere_label_image((20, 64, 64), SPHERE_SPECS_SHIFTED) + + +# --------------------------------------------------------------------------- +# find_all_objects_2D +# --------------------------------------------------------------------------- + +class TestFindAllObjects2D: + def test_labels_match(self, label_2d): + labels, _ = find_all_objects_2D(label_2d) + assert set(labels.tolist()) == {1, 2, 3, 4} + + def test_bbox_matches_skimage(self, label_2d): + labels, bboxes = find_all_objects_2D(label_2d) + ref = _reference_bboxes_2D(label_2d) + for lbl, bbox in zip(labels.tolist(), bboxes.tolist()): + r0, r1, c0, c1 = bbox + assert (r0, r1, c0, c1) == ref[lbl], ( + f"Label {lbl}: got {(r0,r1,c0,c1)}, expected {ref[lbl]}" + ) + + def test_empty_image_returns_empty(self): + empty = np.zeros((64, 64), dtype=np.uint32) + result = find_all_objects_2D(empty) + assert result == ([], []) + + def test_single_pixel_object(self): + img = np.zeros((10, 10), dtype=np.uint32) + img[5, 7] = 1 + labels, bboxes = find_all_objects_2D(img) + assert labels[0] == 1 + assert list(bboxes[0]) == [5, 6, 7, 8] + + def test_label_above_300_triggers_growth(self): + """Label > 300 must still produce the correct bounding box.""" + img = np.zeros((64, 64), dtype=np.uint32) + rr, cc = disk((32, 32), 10, shape=img.shape) + img[rr, cc] = 350 + labels, bboxes = find_all_objects_2D(img) + ref = _reference_bboxes_2D(img) + r0, r1, c0, c1 = bboxes[0].tolist() + assert (r0, r1, c0, c1) == ref[350] + + def test_label_above_600_triggers_second_growth(self): + img = np.zeros((64, 64), dtype=np.uint32) + rr, cc = disk((32, 32), 10, shape=img.shape) + img[rr, cc] = 650 + labels, bboxes = find_all_objects_2D(img) + ref = _reference_bboxes_2D(img) + r0, r1, c0, c1 = bboxes[0].tolist() + assert (r0, r1, c0, c1) == ref[650] + + def test_multiple_labels_across_300_boundary(self): + img = np.zeros((64, 64), dtype=np.uint32) + rr1, cc1 = disk((20, 20), 8, shape=img.shape) + rr2, cc2 = disk((20, 50), 8, shape=img.shape) + img[rr1, cc1] = 1 + img[rr2, cc2] = 301 + labels, bboxes = find_all_objects_2D(img) + ref = _reference_bboxes_2D(img) + for lbl, bbox in zip(labels.tolist(), bboxes.tolist()): + assert tuple(bbox) == ref[lbl] + + def test_bbox_dtype_is_uint32(self, label_2d): + _, bboxes = find_all_objects_2D(label_2d) + assert bboxes.dtype == np.uint32 + + +# --------------------------------------------------------------------------- +# find_all_objects_3D +# --------------------------------------------------------------------------- + +class TestFindAllObjects3D: + def test_labels_match(self, label_3d): + labels, _ = find_all_objects_3D(label_3d) + assert set(labels.tolist()) == {1, 2, 3} + + def test_bbox_matches_skimage(self, label_3d): + labels, bboxes = find_all_objects_3D(label_3d) + ref = _reference_bboxes_3D(label_3d) + for lbl, bbox in zip(labels.tolist(), bboxes.tolist()): + z0, z1, r0, r1, c0, c1 = bbox + assert (z0, z1, r0, r1, c0, c1) == ref[lbl], ( + f"Label {lbl}: got {(z0,z1,r0,r1,c0,c1)}, expected {ref[lbl]}" + ) + + def test_empty_image_returns_empty(self): + empty = np.zeros((8, 16, 16), dtype=np.uint32) + result = find_all_objects_3D(empty) + assert result == ([], []) + + def test_single_voxel_object(self): + img = np.zeros((8, 8, 8), dtype=np.uint32) + img[3, 4, 5] = 2 + labels, bboxes = find_all_objects_3D(img) + assert labels[0] == 2 + assert list(bboxes[0]) == [3, 4, 4, 5, 5, 6] + + def test_bbox_dtype_is_uint32(self, label_3d): + _, bboxes = find_all_objects_3D(label_3d) + assert bboxes.dtype == np.uint32 + + +# --------------------------------------------------------------------------- +# most_common_projection_3D +# --------------------------------------------------------------------------- + +class TestMostCommonProjection3D: + @pytest.mark.parametrize( + "axis, expected", + [ + ( + 0, + np.array( + [ + [1, 1], + [1, 2], + [2, 2], + [2, 2], + [1, 0], + [1, 0], + ], + dtype=np.uint32, + ), + ), + (1, np.array([[1, 2]], dtype=np.uint32)), + (2, np.array([[1, 1, 2, 2, 1, 1]], dtype=np.uint32)), + ], + ) + def test_counts_across_full_axis_not_runs(self, axis, expected): + """A label split into multiple runs must still be counted globally.""" + lab = np.array( + [ + [ + [1, 1], + [1, 2], + [2, 2], + [2, 2], + [1, 0], + [1, 0], + ] + ], + dtype=np.uint32, + ) + out = most_common_projection_3D(lab, axis) + assert out.shape == expected.shape + np.testing.assert_array_equal(out, expected) + + @pytest.mark.parametrize("axis", [0, 1, 2]) + def test_ignores_zero_but_returns_zero_if_no_nonzero(self, axis): + lab = np.zeros((3, 4, 5), dtype=np.uint32) + out = most_common_projection_3D(lab, axis) + assert out.dtype == np.uint32 + np.testing.assert_array_equal(out, np.zeros_like(out, dtype=np.uint32)) + + def test_invalid_axis_raises(self): + lab = np.zeros((2, 2, 2), dtype=np.uint32) + with pytest.raises(ValueError): + most_common_projection_3D(lab, 3) + + +# --------------------------------------------------------------------------- +# object_projections_and_size_3D +# --------------------------------------------------------------------------- + +class TestObjectProjectionsAndSize3D: + def test_matches_numpy_binary_projections_and_voxel_count_for_one_id(self): + cutout = np.zeros((4, 5, 6), dtype=np.uint32) + cutout[0, 1, 2] = 5 + cutout[1, 1, 2] = 5 + cutout[2, 3, 4] = 5 + cutout[3, 0, 1] = 9 + cutout[0, 0, 0] = 7 + + obj_mask = cutout == 5 + proj_z, proj_y, proj_x, size = object_projections_and_size_3D(cutout, 5) + + np.testing.assert_array_equal(proj_z, np.any(obj_mask, axis=0).astype(np.uint8)) + np.testing.assert_array_equal(proj_y, np.any(obj_mask, axis=1).astype(np.uint8)) + np.testing.assert_array_equal(proj_x, np.any(obj_mask, axis=2).astype(np.uint8)) + assert size == int(np.count_nonzero(obj_mask)) + + def test_missing_id_returns_zero_projections_and_zero_size(self): + cutout = np.zeros((3, 4, 5), dtype=np.uint32) + cutout[1, 2, 3] = 4 + proj_z, proj_y, proj_x, size = object_projections_and_size_3D(cutout, 8) + + np.testing.assert_array_equal(proj_z, np.zeros((4, 5), dtype=np.uint8)) + np.testing.assert_array_equal(proj_y, np.zeros((3, 5), dtype=np.uint8)) + np.testing.assert_array_equal(proj_x, np.zeros((3, 4), dtype=np.uint8)) + assert size == 0 + + def test_projection_dtype_is_uint8(self): + cutout = np.zeros((2, 3, 4), dtype=np.uint32) + cutout[1, 2, 3] = 7 + proj_z, proj_y, proj_x, _ = object_projections_and_size_3D(cutout, 7) + + assert proj_z.dtype == np.uint8 + assert proj_y.dtype == np.uint8 + assert proj_x.dtype == np.uint8 + + +class TestObjectProjectionAndSize3D: + @pytest.mark.parametrize("axis", [0, 1, 2]) + def test_matches_numpy_projection_for_axis(self, axis): + cutout = np.zeros((4, 5, 6), dtype=np.uint32) + cutout[0, 1, 2] = 5 + cutout[1, 1, 2] = 5 + cutout[2, 3, 4] = 5 + cutout[3, 0, 1] = 9 + + proj, size = object_projection_and_size_3D(cutout, 5, axis) + + expected_mask = cutout == 5 + expected_proj = np.any(expected_mask, axis=axis).astype(np.uint8) + np.testing.assert_array_equal(proj, expected_proj) + assert size == int(np.count_nonzero(expected_mask)) + + def test_invalid_axis_raises(self): + cutout = np.zeros((2, 3, 4), dtype=np.uint32) + with pytest.raises(ValueError): + object_projection_and_size_3D(cutout, 1, 3) + + +# --------------------------------------------------------------------------- +# calc_IoA_matrix_2D +# --------------------------------------------------------------------------- + +def _run_cython_ioa_2d(lab, prev_lab, use_union): + rp = regionprops(lab.astype(np.int32)) + prev_rp = regionprops(prev_lab.astype(np.int32)) + curr_IDs_arr = np.array([obj.label for obj in rp], dtype=np.uint32) + prev_IDs_arr = np.array([obj.label for obj in prev_rp], dtype=np.uint32) + prev_areas_arr = np.array([obj.area for obj in prev_rp], dtype=np.uint32) + if use_union: + rp_mapper = {obj.label: obj for obj in rp} + curr_areas_arr = np.array( + [rp_mapper[ID].area for ID in curr_IDs_arr.tolist()], dtype=np.uint32 + ) + else: + curr_areas_arr = np.empty(0, dtype=np.uint32) + return ( + calc_IoA_matrix_2D( + lab.astype(np.uint32), prev_lab.astype(np.uint32), + curr_IDs_arr, prev_IDs_arr, + prev_areas_arr, curr_areas_arr, use_union, + ), + rp, prev_rp, + ) + + +class TestCalcIoAMatrix2D: + @pytest.mark.parametrize("use_union", [False, True]) + def test_matches_python_reference(self, label_2d, label_2d_shifted, use_union): + mat_cy, rp, prev_rp = _run_cython_ioa_2d(label_2d_shifted, label_2d, use_union) + mat_py, _, _ = _python_ioa_matrix( + label_2d_shifted, label_2d, rp, prev_rp, use_union + ) + np.testing.assert_allclose(mat_cy, mat_py, rtol=1e-9, atol=1e-12, + err_msg=f"Mismatch for use_union={use_union}") + + def test_shape(self, label_2d, label_2d_shifted): + mat_cy, rp, prev_rp = _run_cython_ioa_2d(label_2d_shifted, label_2d, False) + assert mat_cy.shape == (len(rp), len(prev_rp)) + + def test_values_between_0_and_1(self, label_2d, label_2d_shifted): + mat_cy, _, _ = _run_cython_ioa_2d(label_2d_shifted, label_2d, False) + assert np.all(mat_cy >= 0.0) + assert np.all(mat_cy <= 1.0 + 1e-12) + + def test_no_overlap_gives_zero_matrix(self): + """Two images with disjoint objects should produce an all-zero IoA matrix.""" + lab = _make_disc_label_image((64, 64), [(1, 10, 10, 5)]) + prev_lab = _make_disc_label_image((64, 64), [(1, 50, 50, 5)]) + mat_cy, _, _ = _run_cython_ioa_2d(lab, prev_lab, False) + np.testing.assert_array_equal(mat_cy, np.zeros_like(mat_cy)) + + def test_identical_images_diagonal_is_one(self): + """When lab == prev_lab and IDs match, area_prev IoA should be 1.""" + lab = _make_disc_label_image((64, 64), [ + (1, 20, 20, 8), + (2, 20, 50, 8), + ]) + mat_cy, _, _ = _run_cython_ioa_2d(lab, lab, False) + np.testing.assert_allclose(np.diag(mat_cy), 1.0, rtol=1e-9) + + def test_dtype_is_float64(self, label_2d, label_2d_shifted): + mat_cy, _, _ = _run_cython_ioa_2d(label_2d_shifted, label_2d, False) + assert mat_cy.dtype == np.float64 + + +# --------------------------------------------------------------------------- +# calc_IoA_matrix_3D +# --------------------------------------------------------------------------- + +def _run_cython_ioa_3d(lab, prev_lab, use_union): + rp = regionprops(lab.astype(np.int32)) + prev_rp = regionprops(prev_lab.astype(np.int32)) + curr_IDs_arr = np.array([obj.label for obj in rp], dtype=np.uint32) + prev_IDs_arr = np.array([obj.label for obj in prev_rp], dtype=np.uint32) + prev_areas_arr = np.array([obj.area for obj in prev_rp], dtype=np.uint32) + if use_union: + rp_mapper = {obj.label: obj for obj in rp} + curr_areas_arr = np.array( + [rp_mapper[ID].area for ID in curr_IDs_arr.tolist()], dtype=np.uint32 + ) + else: + curr_areas_arr = np.empty(0, dtype=np.uint32) + return ( + calc_IoA_matrix_3D( + lab.astype(np.uint32), prev_lab.astype(np.uint32), + curr_IDs_arr, prev_IDs_arr, + prev_areas_arr, curr_areas_arr, use_union, + ), + rp, prev_rp, + ) + + +class TestCalcIoAMatrix3D: + @pytest.mark.parametrize("use_union", [False, True]) + def test_matches_python_reference(self, label_3d, label_3d_shifted, use_union): + mat_cy, rp, prev_rp = _run_cython_ioa_3d(label_3d_shifted, label_3d, use_union) + mat_py, _, _ = _python_ioa_matrix( + label_3d_shifted, label_3d, rp, prev_rp, use_union + ) + np.testing.assert_allclose(mat_cy, mat_py, rtol=1e-9, atol=1e-12, + err_msg=f"3D mismatch for use_union={use_union}") + + def test_shape(self, label_3d, label_3d_shifted): + mat_cy, rp, prev_rp = _run_cython_ioa_3d(label_3d_shifted, label_3d, False) + assert mat_cy.shape == (len(rp), len(prev_rp)) + + def test_values_between_0_and_1(self, label_3d, label_3d_shifted): + mat_cy, _, _ = _run_cython_ioa_3d(label_3d_shifted, label_3d, False) + assert np.all(mat_cy >= 0.0) + assert np.all(mat_cy <= 1.0 + 1e-12) + + def test_no_overlap_gives_zero_matrix(self): + lab = _make_sphere_label_image((20, 32, 32), [(1, 5, 8, 8, 2, 4, 4)]) + prev_lab = _make_sphere_label_image((20, 32, 32), [(1, 15, 24, 24, 2, 4, 4)]) + mat_cy, _, _ = _run_cython_ioa_3d(lab, prev_lab, False) + np.testing.assert_array_equal(mat_cy, np.zeros_like(mat_cy)) + + def test_identical_images_diagonal_is_one(self): + lab = _make_sphere_label_image((20, 32, 32), [ + (1, 8, 10, 10, 3, 4, 4), + (2, 8, 10, 22, 3, 4, 4), + ]) + mat_cy, _, _ = _run_cython_ioa_3d(lab, lab, False) + np.testing.assert_allclose(np.diag(mat_cy), 1.0, rtol=1e-9) + + def test_dtype_is_float64(self, label_3d, label_3d_shifted): + mat_cy, _, _ = _run_cython_ioa_3d(label_3d_shifted, label_3d, False) + assert mat_cy.dtype == np.float64 diff --git a/tests/test_regionprops_cutout_update.py b/tests/test_regionprops_cutout_update.py new file mode 100644 index 000000000..a5964fb48 --- /dev/null +++ b/tests/test_regionprops_cutout_update.py @@ -0,0 +1,328 @@ +import pytest +import numpy as np + +from cellacdc.regionprops import acdcRegionprops + + +def test_update_regionprops_via_cutout_reuses_cutout_object_with_global_coords(): + old_lab = np.zeros((12, 12), dtype=np.uint16) + old_lab[1:3, 1:3] = 1 + + new_lab = old_lab.copy() + new_lab[5:7, 6:9] = 2 + + rp = acdcRegionprops(old_lab) + rp.update_regionprops_via_cutout(new_lab, cutout_bbox=(4, 5, 9, 10)) + + obj = rp.get_obj_from_ID(2) + + assert obj is not None + assert obj.bbox == (5, 6, 7, 9) + assert tuple((slc.start, slc.stop) for slc in obj.slice) == ((5, 7), (6, 9)) + assert obj.centroid == (5.5, 7.0) + np.testing.assert_array_equal( + obj.coords, + np.array( + [ + [5, 6], + [5, 7], + [5, 8], + [6, 6], + [6, 7], + [6, 8], + ] + ), + ) + assert new_lab[obj.slice].shape == obj.image.shape + assert np.all(new_lab[obj.slice][obj.image] == 2) + + +def test_update_regionprops_via_cutout_refreshes_preserved_id_image(): + old_lab = np.zeros((10, 10), dtype=np.uint16) + old_lab[2:4, 2:4] = 1 + + new_lab = old_lab.copy() + new_lab[2:5, 2:5] = 1 + + rp = acdcRegionprops(old_lab) + rp.update_regionprops_via_cutout(new_lab, cutout_bbox=(1, 1, 6, 6)) + + obj = rp.get_obj_from_ID(1) + + assert obj is not None + assert obj.bbox == (2, 2, 5, 5) + np.testing.assert_array_equal(obj.image, new_lab[obj.slice] == 1) + np.testing.assert_array_equal(obj.coords, np.argwhere(new_lab == 1)) + + +def test_update_regionprops_via_cutout_batches_border_touching_ids(): + old_lab = np.zeros((14, 14), dtype=np.uint16) + old_lab[1:3, 1:3] = 1 + + new_lab = old_lab.copy() + new_lab[4:9, 5:8] = 2 + new_lab[7:11, 9:13] = 3 + + rp = acdcRegionprops(old_lab) + rp.update_regionprops_via_cutout(new_lab, cutout_bbox=(6, 7, 10, 12)) + + obj2 = rp.get_obj_from_ID(2) + obj3 = rp.get_obj_from_ID(3) + + assert obj2 is not None + assert obj2.bbox == (4, 5, 9, 8) + np.testing.assert_array_equal(obj2.coords, np.argwhere(new_lab == 2)) + + assert obj3 is not None + assert obj3.bbox == (7, 9, 11, 13) + np.testing.assert_array_equal(obj3.coords, np.argwhere(new_lab == 3)) + + +def test_update_regionprops_via_assignments_rebinds_label_image(): + old_lab = np.zeros((8, 8), dtype=np.uint16) + old_lab[2:5, 3:6] = 1 + + new_lab = np.zeros_like(old_lab) + new_lab[2:5, 3:6] = 7 + + rp = acdcRegionprops(old_lab) + rp.update_regionprops_via_assignments({1: 7}, new_lab) + + obj = rp.get_obj_from_ID(7) + + assert obj is not None + assert rp.lab is new_lab + assert obj._label_image is new_lab + np.testing.assert_array_equal(obj.image, new_lab[obj.slice] == 7) + np.testing.assert_array_equal(obj.coords, np.argwhere(new_lab == 7)) + + +def test_slice_regionprops_are_lazy_and_initialized_on_access(): + lab = np.zeros((3, 6, 6), dtype=np.uint16) + lab[1, 2:4, 1:3] = 4 + + rp = acdcRegionprops(lab) + + assert rp._slice_rps['z'] == {} + assert rp._slice_rps['y'] == {} + assert rp._slice_rps['x'] == {} + + z1 = rp.get_slice_rp(1) + assert z1 is not None + assert 1 in rp._slice_rps['z'] + assert rp._slice_rps['z'][1] is z1 + assert z1.lab.ndim == 2 + assert z1.get_obj_from_ID(4) is not None + + +def test_projection_regionprops_are_lazy_and_initialized_on_access(): + lab = np.zeros((3, 6, 6), dtype=np.uint16) + lab[1, 2:4, 1:3] = 4 + + rp = acdcRegionprops(lab) + + assert rp._proj_rps['z'] == {} + assert rp._proj_rps['y'] == {} + assert rp._proj_rps['x'] == {} + + zmax = rp.get_proj_rp(kind='max', slicing='z') + assert zmax is not None + assert 'max' in rp._proj_rps['z'] + assert rp._proj_rps['z']['max'] is zmax + assert zmax.lab.ndim == 2 + assert zmax.get_obj_from_ID(4) is not None + + +def test_projection_regionprops_support_most_common_kind(): + lab = np.array( + [ + [[0, 1], [2, 2]], + [[0, 1], [2, 3]], + [[4, 1], [0, 3]], + ], + dtype=np.uint16, + ) + + rp = acdcRegionprops(lab) + most_common = rp.get_proj_rp(kind='most common', slicing='z') + + expected = np.array( + [[4, 1], [2, 3]], + dtype=np.uint16, + ) + np.testing.assert_array_equal(most_common.lab, expected) + assert rp.get_obj_from_proj_rp(3, kind='most common z-projection', warn=False) is not None + + +def test_most_common_projection_uses_local_cutout_update(monkeypatch): + old_lab = np.zeros((3, 6, 6), dtype=np.uint16) + old_lab[:, 1:4, 1:4] = 1 + + rp = acdcRegionprops(old_lab) + proj_before = rp.get_proj_rp(kind='most_common', slicing='z') + expected_before = rp._get_lab_projection(old_lab, slicing='z', kind='most_common') + np.testing.assert_array_equal(proj_before.lab, expected_before) + + new_lab = old_lab.copy() + new_lab[0:2, 2:5, 2:5] = 2 + + original_replace_cached = rp._replace_cached_lab_projection + + def _replace_cached_should_not_run_for_most_common(slicing, kind): + if kind == 'most_common': + raise AssertionError( + 'most_common projection should be updated locally for cutout updates.' + ) + return original_replace_cached(slicing, kind) + + monkeypatch.setattr(rp, '_replace_cached_lab_projection', _replace_cached_should_not_run_for_most_common) + + rp.update_regionprops_via_cutout(new_lab, cutout_bbox=(2, 2, 5, 5)) + + proj_after = rp.get_proj_rp(kind='most_common', slicing='z') + expected_after = rp._get_lab_projection(new_lab, slicing='z', kind='most_common') + np.testing.assert_array_equal(proj_after.lab, expected_after) + + +@pytest.mark.parametrize('slicing', ['z', 'y', 'x']) +def test_most_common_projection_uses_local_cutout_update_for_all_slicings(monkeypatch, slicing): + old_lab = np.zeros((4, 7, 8), dtype=np.uint16) + old_lab[1:3, 1:4, 1:4] = 1 + old_lab[0:2, 4:6, 4:7] = 2 + + rp = acdcRegionprops(old_lab) + proj_before = rp.get_proj_rp(kind='most_common', slicing=slicing) + expected_before = rp._get_lab_projection(old_lab, slicing=slicing, kind='most_common') + np.testing.assert_array_equal(proj_before.lab, expected_before) + + new_lab = old_lab.copy() + new_lab[:, 2:6, 3:7] = np.array( + [ + [ + [0, 2, 2, 0], + [3, 3, 2, 0], + [3, 3, 2, 0], + [0, 0, 0, 0], + ], + [ + [0, 2, 2, 0], + [3, 3, 3, 0], + [3, 3, 3, 0], + [0, 0, 0, 0], + ], + [ + [0, 0, 0, 0], + [3, 3, 3, 0], + [3, 3, 3, 0], + [0, 0, 0, 0], + ], + [ + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + ], + ], + dtype=np.uint16, + ) + + original_replace_cached = rp._replace_cached_lab_projection + + def _replace_cached_should_not_run_for_most_common(slicing_name, kind): + if kind == 'most_common': + raise AssertionError( + 'most_common projection should be updated locally for cutout updates.' + ) + return original_replace_cached(slicing_name, kind) + + monkeypatch.setattr( + rp, + '_replace_cached_lab_projection', + _replace_cached_should_not_run_for_most_common, + ) + + rp.update_regionprops_via_cutout(new_lab, cutout_bbox=(2, 3, 6, 7)) + + proj_after = rp.get_proj_rp(kind='most_common', slicing=slicing) + expected_after = rp._get_lab_projection(new_lab, slicing=slicing, kind='most_common') + np.testing.assert_array_equal(proj_after.lab, expected_after) + + +def test_get_obj_from_id_for_stored_slice_and_projection_rps(): + lab = np.zeros((4, 8, 8), dtype=np.uint16) + lab[1:3, 2:5, 3:6] = 2 + + rp = acdcRegionprops(lab) + + obj_slice = rp.get_obj_from_slice_rp(2, slice_number=1, slicing='z', warn=False) + assert obj_slice is not None + + obj_proj = rp.get_obj_from_proj_rp(2, kind='max', slicing='z', warn=False) + assert obj_proj is not None + + +def test_slice_regionprops_follow_assignments_and_deletions(): + lab = np.zeros((4, 8, 8), dtype=np.uint16) + lab[1:3, 2:5, 3:6] = 2 + + rp = acdcRegionprops(lab) + _ = rp.get_slice_rp(1, 'z') + _ = rp.get_slice_rp(2, 'y') + _ = rp.get_slice_rp(3, 'x') + _ = rp.get_proj_rp('max', 'z') + _ = rp.get_proj_rp('mean', 'y') + _ = rp.get_proj_rp('median', 'x') + + remapped_lab = np.zeros_like(lab) + remapped_lab[1:3, 2:5, 3:6] = 9 + rp.update_regionprops_via_assignments({2: 9}, remapped_lab) + + assert rp.get_slice_rp(1, 'z').get_obj_from_ID(9, warn=False) is not None + assert rp.get_slice_rp(2, 'y').get_obj_from_ID(9, warn=False) is not None + assert rp.get_slice_rp(3, 'x').get_obj_from_ID(9, warn=False) is not None + assert rp.get_proj_rp('max', 'z').get_obj_from_ID(9, warn=False) is not None + assert rp.get_slice_rp(1, 'z').get_obj_from_ID(2, warn=False) is None + + rp.update_regionprops_via_deletions({9}) + assert rp.get_slice_rp(1, 'z').get_obj_from_ID(9, warn=False) is None + assert rp.get_slice_rp(2, 'y').get_obj_from_ID(9, warn=False) is None + assert rp.get_slice_rp(3, 'x').get_obj_from_ID(9, warn=False) is None + assert rp.get_proj_rp('max', 'z').get_obj_from_ID(9, warn=False) is None + + +def test_slice_regionprops_update_from_2d_cutout_on_3d(): + old_lab = np.zeros((5, 12, 12), dtype=np.uint16) + old_lab[1:4, 2:4, 2:4] = 5 + + new_lab = old_lab.copy() + new_lab[3, 6:9, 7:10] = 6 + + rp = acdcRegionprops(old_lab) + _ = rp.get_slice_rp(3, 'z') + _ = rp.get_slice_rp(6, 'y') + _ = rp.get_slice_rp(7, 'x') + + # 2D cutout on y/x; implementation expands to all z for touched IDs. + rp.update_regionprops_via_cutout(new_lab, cutout_bbox=(5, 6, 10, 11)) + + assert rp.get_obj_from_ID(6, warn=False) is not None + assert rp.get_slice_rp(3, 'z').get_obj_from_ID(6, warn=False) is not None + assert rp.get_slice_rp(6, 'y').get_obj_from_ID(6, warn=False) is not None + assert rp.get_slice_rp(7, 'x').get_obj_from_ID(6, warn=False) is not None + + +def test_projection_lab_sorted_draws_large_first_small_on_top(): + lab = np.zeros((4, 8, 8), dtype=np.uint16) + # Larger object + lab[0:4, 1:7, 1:7] = 1 + # Smaller object overlapping the center + lab[1:3, 3:5, 3:5] = 2 + + rp = acdcRegionprops(lab) + proj = rp.get_projection_lab_sorted(slicing='z') + + # Small object should be visible on top in overlap area. + np.testing.assert_array_equal(proj[3:5, 3:5], np.full((2, 2), 2, dtype=np.uint16)) + assert proj[2, 2] == 1 + + diff --git a/tests/test_segm_utils.py b/tests/test_segm_utils.py new file mode 100644 index 000000000..a8f4c850e --- /dev/null +++ b/tests/test_segm_utils.py @@ -0,0 +1,175 @@ +import types + +import numpy as np +from skimage.measure import regionprops +from types import SimpleNamespace + +from cellacdc import myutils +from cellacdc.regionprops import acdcRegionprops +from cellacdc.workers import SegForLostIDsWorker + +from cellacdc.models.thresholding.acdcSegment import Model as ThresholdingModel +from cellacdc.segm_utils import get_best_overlapping_label + + +def test_get_best_overlapping_label_uses_majority_overlap_with_allowed_labels(): + label_img = np.zeros((8, 8), dtype=np.uint16) + label_img[2:6, 2:4] = 4 + label_img[3:7, 4:6] = 7 + + obj = types.SimpleNamespace( + slice=(slice(2, 7), slice(2, 6)), + image=np.array( + [ + [False, False, False, False], + [True, True, False, False], + [True, True, True, True], + [True, True, True, True], + [False, False, True, True], + ] + ), + ) + + assert get_best_overlapping_label(label_img, obj, {4, 7}) == 4 + + +def test_get_best_overlapping_label_returns_none_without_allowed_overlap(): + label_img = np.zeros((6, 6), dtype=np.uint16) + label_img[1:3, 1:3] = 2 + + obj = types.SimpleNamespace( + slice=(slice(1, 4), slice(1, 4)), + image=np.array( + [ + [False, True, False], + [True, True, True], + [False, True, False], + ] + ), + ) + + assert get_best_overlapping_label(label_img, obj, {5}) is None + + +def test_thresholding_model_object_can_be_mapped_back_to_missing_id(): + prev_lab = np.zeros((10, 10), dtype=np.uint16) + prev_lab[3:7, 3:7] = 5 + + image = np.zeros((10, 10), dtype=np.float32) + image[3:7, 3:7] = 10.0 + + model = ThresholdingModel() + model_lab = model.segment( + image, + gauss_sigma=0, + threshold_method='threshold_otsu', + ) + + rp_model = regionprops(model_lab) + assert len(rp_model) == 1 + + recovered_id = get_best_overlapping_label(prev_lab, rp_model[0], {5}) + + assert recovered_id == 5 + + +class _DummyLogger: + def info(self, message): + pass + + def warning(self, message): + pass + + def error(self, message): + pass + + +class _DummySignal: + def emit(self, *args, **kwargs): + pass + + +class _DummySignals: + def __init__(self): + self.progress = _DummySignal() + self.finished = _DummySignal() + self.initProgressBar = _DummySignal() + self.progressBar = _DummySignal() + self.critical = _DummySignal() + + +def test_seg_for_lost_ids_worker_thresholding_relabels_recovered_object(monkeypatch): + prev_lab = np.zeros((10, 10), dtype=np.uint16) + prev_lab[3:7, 3:7] = 5 + + curr_lab = np.zeros((10, 10), dtype=np.uint16) + curr_img = np.zeros((10, 10), dtype=np.float32) + curr_img[3:7, 3:7] = 10.0 + + prev_rp = acdcRegionprops(prev_lab) + curr_rp = acdcRegionprops(curr_lab) + + posData = SimpleNamespace( + frame_i=1, + lab=curr_lab.copy(), + rp=curr_rp, + allData_li=[ + {'labels': prev_lab, 'regionprops': prev_rp}, + {'labels': curr_lab, 'regionprops': curr_rp}, + ], + ) + + guiWin = SimpleNamespace( + data=[posData], + pos_i=0, + SegForLostIDsSettings={ + 'models_settings': [ + { + 'base_model_name': 'thresholding', + 'init_kwargs_new': {}, + 'args_new': { + 'distance_filler_growth': 1.0, + 'overlap_threshold': 0.5, + 'padding': 1.0, + 'size_perc_diff': 1.0, + 'allow_only_tracked_cells': True, + }, + 'init_kwargs': {}, + 'model_kwargs': { + 'gauss_sigma': 0, + 'threshold_method': 'threshold_otsu', + }, + 'preproc_recipe': None, + 'applyPostProcessing': False, + 'standardPostProcessKwargs': {}, + 'customPostProcessFeatures': None, + 'customPostProcessGroupedFeatures': None, + } + ] + }, + getDisplayedImg1=lambda: curr_img, + get_2Dlab=lambda lab: lab, + getTrackedLostIDs=lambda: [], + setBrushID=lambda useCurrentLab=True, return_val=True: 10, + logger=_DummyLogger(), + ) + + worker = SegForLostIDsWorker(guiWin, mutex=SimpleNamespace(lock=lambda: None, unlock=lambda: None), waitCond=SimpleNamespace(wait=lambda mutex: None)) + worker.signals = _DummySignals() + worker.logger = _DummyLogger() + worker.gpu_go = True + worker.dont_force_cpu = True + + monkeypatch.setattr(worker, 'emitSigAskInit', lambda: None) + monkeypatch.setattr(worker, 'emitSigAskInstallGPU', lambda base_model_name, use_gpu: None) + monkeypatch.setattr(worker, 'emitSigUpdateRP', lambda wl_update=True, wl_track_og_curr=False: None) + monkeypatch.setattr(worker, 'emitSigStoreData', lambda autosave=True: None) + monkeypatch.setattr(worker, 'emitTrackManuallyAddedObject', lambda *args, **kwargs: None) + monkeypatch.setattr(myutils, 'import_segment_module', lambda base_model_name: SimpleNamespace(Model=ThresholdingModel)) + monkeypatch.setattr(myutils, 'init_segm_model', lambda acdcSegment, posData, init_kwargs_new: ThresholdingModel()) + + monkeypatch.setattr(worker, 'emitGetSegForLostIDsInputImg', lambda image_channel_name: curr_img) + worker.run() + + assert posData.lab[3:7, 3:7].min() == 5 + assert posData.lab[3:7, 3:7].max() == 5 \ No newline at end of file