diff --git a/.github/workflows/build_cython_extensions.yml b/.github/workflows/build_cython_extensions.yml
new file mode 100644
index 000000000..9853df048
--- /dev/null
+++ b/.github/workflows/build_cython_extensions.yml
@@ -0,0 +1,111 @@
+name: Build Cython extensions
+
+permissions:
+ contents: write
+
+on:
+ push:
+ paths:
+ - "cellacdc/**/*.pyx"
+ - "precompile_functions.py"
+ workflow_dispatch:
+
+jobs:
+ build:
+ runs-on: ${{ matrix.os }}
+ defaults:
+ run:
+ working-directory: ${{ github.workspace }}
+ strategy:
+ matrix:
+ os: [ubuntu-latest, macos-latest, windows-latest]
+ python-version: ["3.10", "3.11", "3.12", "3.13"]
+
+ steps:
+ - uses: actions/checkout@v6
+ with:
+ fetch-depth: 0
+
+ - uses: actions/setup-python@v6
+ with:
+ python-version: ${{ matrix.python-version }}
+
+ - name: Build extension
+ run: |
+ python -m pip install --upgrade "setuptools>=77" cython numpy
+ # Remove existing binaries so only the freshly built one remains for upload
+ python -c "
+ import os, glob
+ for f in glob.glob('cellacdc/precompiled/*.so') + glob.glob('cellacdc/precompiled/*.pyd'):
+ os.remove(f)
+ print(f'Removed existing: {f}')
+ "
+ python "${{ github.workspace }}/precompile_functions.py" build_ext --inplace --build-temp "${{ github.workspace }}/build/temp"
+ echo "Built binary:"
+ ls cellacdc/precompiled/
+
+ - uses: actions/upload-artifact@v6
+ with:
+ name: precompiled-${{ matrix.os }}-py${{ matrix.python-version }}
+ path: |
+ cellacdc/precompiled/*.so
+ cellacdc/precompiled/*.pyd
+
+ commit:
+ needs: build
+ runs-on: ubuntu-latest
+ permissions:
+ contents: write
+ defaults:
+ run:
+ working-directory: ${{ github.workspace }}
+ steps:
+ - uses: actions/checkout@v6
+ with:
+ ref: ${{ github.ref_name }}
+ fetch-depth: 0
+
+ - uses: actions/download-artifact@v6
+ with:
+ path: artifacts/
+
+ - name: Flatten artifacts into precompiled directory
+ run: |
+ mkdir -p cellacdc/precompiled
+ find artifacts -type f \( -name "*.pyd" -o -name "*.so" \) -exec cp {} cellacdc/precompiled/ \;
+ echo "Files copied:"
+ ls -la cellacdc/precompiled/
+ echo ""
+ echo "Total files:"
+ find cellacdc/precompiled/ -type f | grep -E '\.(pyd|so)$' | wc -l
+
+ - name: Validate artifact count
+ run: |
+ EXPECTED_COUNT=12
+ FILE_COUNT=$(find cellacdc/precompiled/ -type f \( -name "*.pyd" -o -name "*.so" \) | wc -l)
+ echo "Expected files: $EXPECTED_COUNT"
+ echo "Found files: $FILE_COUNT"
+ if [ "$FILE_COUNT" -lt "$EXPECTED_COUNT" ]; then
+ echo "Missing binaries: expected at least $EXPECTED_COUNT, found $FILE_COUNT"
+ exit 1
+ fi
+
+ - name: Commit precompiled binaries
+ run: |
+ git config user.name "github-actions[bot]"
+ git config user.email "github-actions[bot]@users.noreply.github.com"
+ find cellacdc/precompiled/ -type f \( -name "*.pyd" -o -name "*.so" \) -print0 | xargs -0 git add -f
+ git add cellacdc/precompiled/__init__.py
+ FILE_COUNT=$(find cellacdc/precompiled/ -type f \( -name "*.pyd" -o -name "*.so" \) | wc -l)
+ echo "Files to commit: $FILE_COUNT"
+ if [ "$FILE_COUNT" -eq 0 ]; then
+ echo "No precompiled files found, skipping commit"
+ exit 0
+ fi
+ if git diff --cached --quiet; then
+ echo "No changes to commit"
+ exit 0
+ fi
+ git commit -m "ci: update precompiled Cython extensions [skip ci]"
+ git pull --rebase origin ${{ github.ref_name }}
+ git push origin HEAD:${{ github.ref_name }}
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 6d90763c1..3b45e534f 100755
--- a/.gitignore
+++ b/.gitignore
@@ -26,6 +26,13 @@ requirements_new.txt
**/weights_location_path.txt
**/_test
+# cython generated files
+*.pyd
+*.so
+!cellacdc/precompiled/
+!cellacdc/precompiled/*.pyd
+!cellacdc/precompiled/*.so
+
# Test output plots
tests/_plots/
@@ -36,7 +43,6 @@ setup.cfg
# Starting from pip 21.3 setup.py is not needed anymore
# and we rely only on setup.cfg for env
-setup.py
environment.yml
# requirements.txt
conda_env_list_commands.txt
diff --git a/cellacdc/__init__.py b/cellacdc/__init__.py
index c4fcaf553..15355abdb 100755
--- a/cellacdc/__init__.py
+++ b/cellacdc/__init__.py
@@ -806,4 +806,17 @@ def inner_function(self, *args, **kwargs):
single_pos_index_cols = (
'experiment_folderpath',
'Position_n'
-)
\ No newline at end of file
+)
+
+precompiled_import_success = False
+try:
+ from cellacdc.precompiled.precompiled_functions import (
+ find_all_objects_2D,
+ find_all_objects_3D,
+ calc_IoA_matrix_2D,
+ calc_IoA_matrix_3D,
+ most_common_projection_3D,
+ )
+ precompiled_import_success = True
+except Exception:
+ pass
\ No newline at end of file
diff --git a/cellacdc/_warnings.py b/cellacdc/_warnings.py
index e313b3472..e5e2042f3 100644
--- a/cellacdc/_warnings.py
+++ b/cellacdc/_warnings.py
@@ -32,6 +32,34 @@ def warnTooManyItems(mainWin, numItems, qparent):
)
return msg.cancel, msg.clickedButton==switchToLowResButton
+def warnTooManyNewItems(mainWin, numItems, qparent):
+ from . import widgets
+ mainWin.logger.info(
+ '[WARNING]: asking user what to do with too many objects...'
+ )
+ msg = widgets.myMessageBox(wrapText=False)
+ txt = html_utils.paragraph(f"""
+ WARNING: The resulting segmentation mask has {numItems} objects.
+ Creating high resolution text annotations
+ for these many objects could take a long time.
+ We recommend deactivating text annotations or switching to low resolution annotations.
+ You can still try to switch to activate them or switch to high resolution later.
+ What do you want to do?
+ """)
+
+ _, switchToLowResButton, deactivateAnnotButton = msg.warning(
+ qparent, 'Too many objects', txt,
+ buttonsTexts=(
+ 'Cancel',
+ widgets.reloadPushButton(' Switch to low resolution '),
+ widgets.noPushButton(' Deactivate text annotations ')
+ )
+ )
+ switchToLowRes = msg.clickedButton==switchToLowResButton
+ deactivateAnnot = msg.clickedButton==deactivateAnnotButton
+
+ return msg.cancel, switchToLowRes, deactivateAnnot
+
def warnRestartCellACDCcolorModeToggled(
scheme, app_name='Cell-ACDC', parent=None
):
diff --git a/cellacdc/annotate.py b/cellacdc/annotate.py
index ec09b3610..cbe9f2c97 100644
--- a/cellacdc/annotate.py
+++ b/cellacdc/annotate.py
@@ -5,7 +5,7 @@
import pandas as pd
from . import GUI_INSTALLED
-from . import cellacdc_path, printl, ignore_exception
+from . import cellacdc_path, printl, ignore_exception, debugutils
if GUI_INSTALLED:
from PIL import Image, ImageFont, ImageDraw
@@ -169,9 +169,10 @@ def appendData(self, data, text):
self.annotData.append(data)
self.texts.append(text)
- def highlightObject(self, obj):
+ def highlightObject(self, obj, rp=None, getObjCentroidFunc=None):
self.highlighterItem.texts = self.texts
- self.highlighterItem.highlightObject(obj)
+ self.highlighterItem.highlightObject(
+ obj, rp=rp, getObjCentroidFunc=getObjCentroidFunc)
def grayOutAnnotations(self, IDsToSkip=None):
self.setOpacity(0.3)
@@ -207,6 +208,45 @@ def __init__(self, *args, anchor=(0.5, 0.5), **kargs):
self.texts = []
self.annotData = []
self._anchor = anchor
+
+ def _rebuildSizes(self, bold=False):
+ if bold:
+ self.sizesBold = plot.get_symbol_sizes(
+ self.scalesBold, self.symbolsBold, self.fontSize
+ )
+ self._maxScaleBold = max(self.scalesBold.values(), default=None)
+ else:
+ self.sizesRegular = plot.get_symbol_sizes(
+ self.scalesRegular, self.symbolsRegular, self.fontSize
+ )
+ self._maxScaleRegular = max(self.scalesRegular.values(), default=None)
+
+ def _updateSizesForTexts(self, texts, bold=False):
+ if not texts:
+ return
+
+ if bold:
+ scales = self.scalesBold
+ sizes_attr = 'sizesBold'
+ max_scale_attr = '_maxScaleBold'
+ else:
+ scales = self.scalesRegular
+ sizes_attr = 'sizesRegular'
+ max_scale_attr = '_maxScaleRegular'
+
+ current_max_scale = getattr(self, max_scale_attr, None)
+ if current_max_scale is None:
+ self._rebuildSizes(bold=bold)
+ return
+
+ added_max_scale = max(scales[text] for text in texts)
+ if added_max_scale > current_max_scale:
+ self._rebuildSizes(bold=bold)
+ return
+
+ sizes = getattr(self, sizes_attr)
+ for text in texts:
+ sizes[text] = int(np.round(self.fontSize*current_max_scale/scales[text]))
def clearData(self):
self.setData([], [])
@@ -254,15 +294,28 @@ def initSymbols(self, allIDs, onlyIDs=False):
self.createSymbols(annotTexts)
def addSymbols(self, annotTexts, includeBold=True):
- for text in annotTexts:
- if includeBold:
- self.symbolsBold[text] = self.getObjTextAnnotSymbol(
- text, bold=True, initSizes=False
+ if includeBold:
+ missing_bold = [
+ text for text in annotTexts if text not in self.symbolsBold
+ ]
+ if missing_bold:
+ symbolsBold, scalesBold = plot.texts_to_pg_scatter_symbols(
+ missing_bold, font=self.fontBold, return_scales=True
)
- self.symbolsRegular[text] = self.getObjTextAnnotSymbol(
- text, bold=True, initSizes=False
+ self.symbolsBold.update(symbolsBold)
+ self.scalesBold.update(scalesBold)
+ self._updateSizesForTexts(missing_bold, bold=True)
+
+ missing_regular = [
+ text for text in annotTexts if text not in self.symbolsRegular
+ ]
+ if missing_regular:
+ symbolsRegular, scalesRegular = plot.texts_to_pg_scatter_symbols(
+ missing_regular, font=self.fontRegular, return_scales=True
)
- self.initSizes(includeBold=includeBold)
+ self.symbolsRegular.update(symbolsRegular)
+ self.scalesRegular.update(scalesRegular)
+ self._updateSizesForTexts(missing_regular, bold=False)
def createSymbols(self, annotTexts, includeBold=True):
if includeBold:
@@ -281,12 +334,8 @@ def initSizes(self, includeBold=True):
includeBold = False
if includeBold:
- self.sizesBold = plot.get_symbol_sizes(
- self.scalesBold, self.symbolsBold, self.fontSize
- )
- self.sizesRegular = plot.get_symbol_sizes(
- self.scalesRegular, self.symbolsRegular, self.fontSize
- )
+ self._rebuildSizes(bold=True)
+ self._rebuildSizes(bold=False)
def setColors(self, colors):
self._colors = colors.copy()
@@ -325,7 +374,7 @@ def getObjTextAnnotSymbol(self, text, bold=False, initSizes=True):
symbols[text] = symbol
scales[text] = scale
if initSizes:
- self.initSizes()
+ self._updateSizesForTexts([text], bold=bold)
return symbol
def grayOutAnnotations(self, IDsToSkip=None):
@@ -346,7 +395,7 @@ def grayOutAnnotations(self, IDsToSkip=None):
self.setBrush(brushes)
self.setPen(pens)
- def highlightObject(self, obj):
+ def highlightObject(self, obj, rp=None, getObjCentroidFunc=None):
ID = obj.label
objIdx = None
for idx, objData in enumerate(self.data):
@@ -357,7 +406,14 @@ def highlightObject(self, obj):
objOpts = {
'text': str(ID), 'bold': True, 'color_name': 'new_object'
}
- yc, xc = obj.centroid[-2:]
+ if rp is not None:
+ centroid = rp.get_centroid(obj.label)
+ else:
+ centroid = obj.centroid
+ if getObjCentroidFunc is not None:
+ yc, xc = getObjCentroidFunc(centroid)
+ else:
+ yc, xc = centroid[-2:]
pos = (int(xc), int(yc))
self.addObjAnnot(pos, draw=True, **objOpts)
return
@@ -504,13 +560,20 @@ def removeFromPlotItem(self, ax):
if hasattr(self.item, 'highlighterItem'):
ax.removeItem(self.item.highlighterItem)
- def addObjAnnotation(self, obj, color_name, text, bold):
+ def addObjAnnotation(self, obj, color_name, text, bold, rp=None, getObjCentroidFunc=None):
objOpts = {
'text': text,
'bold': bold,
'color_name': color_name,
}
- yc, xc = obj.centroid[-2:]
+ if rp is not None:
+ centroid = rp.get_centroid(obj.label)
+ else:
+ centroid = obj.centroid
+ if getObjCentroidFunc is not None:
+ yc, xc = getObjCentroidFunc(centroid)
+ else:
+ yc, xc = centroid[-2:]
pos = (int(xc), int(yc))
objData = self.item.addObjAnnot(pos, draw=True, **objOpts)
self.item.appendData(objData, objOpts['text'])
@@ -563,7 +626,8 @@ def setAnnotations(self, **kwargs):
isGenNumTreeAnnotation, posData.frame_i
)
- yc, xc = getObjCentroidFunc(obj.centroid)
+ centroid = posData.rp.get_centroid(obj.label)
+ yc, xc = getObjCentroidFunc(centroid)
try:
rp_zslice = posData.zSlicesRp[currentZ]
obj_2d = rp_zslice[obj.label]
@@ -598,7 +662,8 @@ def setAnnotations(self, **kwargs):
'color_name': 'tracked_lost_object',
'bold': False,
}
- yc, xc = obj.centroid[-2:]
+ centroid = prev_rp.get_centroid(obj.label)
+ yc, xc = getObjCentroidFunc(centroid)
pos = (int(xc), int(yc))
objData = self.item.addObjAnnot(pos, draw=False, **objOpts)
self.item.appendData(objData, objOpts['text'])
@@ -624,7 +689,8 @@ def setAnnotations(self, **kwargs):
'color_name': 'lost_object',
'bold': False,
}
- yc, xc = getObjCentroidFunc(obj.centroid)
+ centroid = prev_rp.get_centroid(obj.label)
+ yc, xc = getObjCentroidFunc(centroid)
try:
pos = (int(xc), int(yc))
except Exception as err:
@@ -638,8 +704,8 @@ def setAnnotations(self, **kwargs):
self.item.draw()
- def highlightObject(self, obj):
- self.item.highlightObject(obj)
+ def highlightObject(self, obj, rp=None, getObjCentroidFunc=None):
+ self.item.highlightObject(obj, rp=rp, getObjCentroidFunc=getObjCentroidFunc)
def removeHighlightObject(self, obj):
self.item.removeHighlightObject(obj)
diff --git a/cellacdc/apps.py b/cellacdc/apps.py
index 5bc4553c0..dc4cd369b 100755
--- a/cellacdc/apps.py
+++ b/cellacdc/apps.py
@@ -94,18 +94,13 @@
from . import io
from . import cca_functions
from . import path
+from . import fonts
POSITIVE_FLOAT_REGEX = float_regex(allow_negative=False)
TREEWIDGET_STYLESHEET = _palettes.TreeWidgetStyleSheet()
LISTWIDGET_STYLESHEET = _palettes.ListWidgetStyleSheet()
BACKGROUND_RGBA = _palettes.get_disabled_colors()['Button']
-font = QFont()
-font.setPixelSize(12)
-italicFont = QFont()
-italicFont.setPixelSize(12)
-italicFont.setItalic(True)
-
class ArgWidget:
def __init__(self, name, type, widget, defaultVal, valueSetter, valueGetter, changeSig=None):
self.name = name
@@ -838,7 +833,7 @@ def __init__(
self.setLayout(mainLayout)
- self.setFont(font)
+ self.setFont(fonts.font)
def addCentroidsSection(self, row, layout, **kwargs):
sectionWidgets = []
@@ -1441,7 +1436,7 @@ def __init__(self, parent=None):
self.setLayout(mainLayout)
- self.setFont(font)
+ self.setFont(fonts.font)
def restoreState(self, state):
self.appearanceGroupbox.restoreState(state)
@@ -1588,7 +1583,7 @@ def __init__(
layout.addLayout(buttonsLayout)
self.setLayout(layout)
- self.setFont(font)
+ self.setFont(fonts.font)
if defaultEntry:
self.updateFilename(defaultEntry)
@@ -1856,7 +1851,7 @@ def __init__(self, basename='', parent=None):
mainLayout.addLayout(buttonsLayout)
self.setLayout(mainLayout)
- self.setFont(font)
+ self.setFont(fonts.font)
def createThirdSegmToggled(self, checked):
self.appendTextWidget.setDisabled(not checked)
@@ -3026,9 +3021,7 @@ def __init__(
self.imageViewer = None
super().__init__(parent)
self.setWindowTitle(title)
- font = QFont()
- font.setPixelSize(12)
- self.setFont(font)
+ self.setFont(fonts.font)
mainLayout = QVBoxLayout()
entriesLayout = QGridLayout()
@@ -3922,7 +3915,7 @@ def __init__(self, parent=None):
layout.addLayout(buttonsLayout)
layout.addStretch(1)
self.setLayout(layout)
- self.setFont(font)
+ self.setFont(fonts.font)
def showInfo(self):
msg = widgets.myMessageBox(wrapText=False)
@@ -4114,7 +4107,7 @@ def __init__(
layout.addLayout(buttonsLayout)
layout.addStretch(1)
self.setLayout(layout)
- self.setFont(font)
+ self.setFont(fonts.font)
def selectFeatures(self):
features = measurements.get_btrack_features()
@@ -4340,7 +4333,7 @@ def __init__(self, posData=None, parent=None):
layout.addLayout(buttonsLayout)
layout.addStretch(1)
self.setLayout(layout)
- self.setFont(font)
+ self.setFont(fonts.font)
def methodChanged(self, method):
if method == 'mothermachine':
@@ -4565,7 +4558,7 @@ def __init__(
self.loop = None
self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint)
- self.setFont(font)
+ self.setFont(fonts.font)
def ok_cb(self, checked=False):
self.cancel = False
@@ -4741,7 +4734,7 @@ def __init__(self, fileName, folderPath, readPatternFunc=None, parent=None):
self.setLayout(mainLayout)
- self.setFont(font)
+ self.setFont(fonts.font)
def segmFolderpathSelected(self, path):
self.segmFolderPathEntry.setText(path)
@@ -4941,7 +4934,7 @@ def __init__(self, parent=None, isSegm3D=True):
cancelButton.clicked.connect(self.close)
self.setLayout(layout)
- self.setFont(font)
+ self.setFont(fonts.font)
self.configPars = self.loadLastSelection()
@@ -5102,7 +5095,7 @@ def __init__(self, df: pd.DataFrame, parent=None):
)
self.setLayout(self.mainLayout)
- self.setFont(font)
+ self.setFont(fonts.font)
def saveSelection(self):
saved_selections = io.get_saved_moth_bud_tot_selections()
@@ -5507,7 +5500,7 @@ def __init__(self, df, parent=None):
self.mainLayout.addLayout(buttonsLayout)
self.setLayout(self.mainLayout)
- self.setFont(font)
+ self.setFont(fonts.font)
def ok_cb(self):
self.cancel = False
@@ -5577,50 +5570,66 @@ def ok_cb(self):
self.model_name = self.listBox.currentItem().text()
self.close()
-
class QDialogSelectModel(QDialog):
def __init__(
- self, parent=None, addSkipSegmButton=False, customFirst=''
+ self, parent=None, addSkipSegmButton=False, customFirst='',
+ allowMultiSelection=False, lastSelection=None,
+ addSelectLastSelectionButton=False,
+ addSelectLastRecipeButton=False,
+ custom_title=None,
+ info_label='',
):
self.cancel = True
+ self.loadLastRecipe = False
super().__init__(parent)
self.setWindowTitle('Select model')
+ self.info_label = info_label
- mainLayout = QVBoxLayout()
- topLayout = QVBoxLayout()
- bottomLayout = QHBoxLayout()
+ self.allowMultiSelection = allowMultiSelection
+ self.lastSelection = []
+ for m in (lastSelection or []):
+ if not isinstance(m, str):
+ continue
+ if m == 'thresholding':
+ m = 'Automatic thresholding'
+ self.lastSelection.append(m)
+ mainLayout = QVBoxLayout()
self.mainLayout = mainLayout
+ title = custom_title or 'Select model to use for segmentation: '
+
+ titleContainer = QWidget(self)
+ titleLayout = QGridLayout(titleContainer)
+ titleLayout.setContentsMargins(0, 0, 0, 0)
+ titleLayout.setSpacing(0)
label = QLabel(html_utils.paragraph(
- 'Select model to use for segmentation: '
+ title
))
- # padding: top, left, bottom, right
label.setStyleSheet("padding:0px 0px 3px 0px;")
- topLayout.addWidget(label, alignment=Qt.AlignCenter)
-
- listBox = widgets.listWidget()
- models = myutils.get_list_of_models()
+ titleLayout.addWidget(label, 0, 0, Qt.AlignCenter)
+ if info_label:
+ moreInfoButton = widgets.infoPushButton()
+ moreInfoButton.clicked.connect(self.showInfoLabel)
+ moreInfoButton.setSizePolicy(
+ QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed
+ )
+ titleLayout.addWidget(moreInfoButton, 0, 0, Qt.AlignTop | Qt.AlignRight)
+ mainLayout.addWidget(titleContainer)
- if customFirst:
- try:
- idx = models.index(customFirst)
- models.insert(0, models.pop(idx))
- except ValueError:
- print(f'Warning: {customFirst} not found in models list.')
- pass
+ self.modelSelector = widgets.ModelSelectionWidget(
+ parent=self,
+ customFirst=customFirst,
+ allowMultiSelection=allowMultiSelection,
+ )
+ # Convenience aliases kept for backward compatibility
+ self.listBox = self.modelSelector.listBox
+ mainLayout.addWidget(self.modelSelector)
- listBox.setFont(font)
- listBox.addItems(models)
- addCustomModelItem = QListWidgetItem('Add custom model...')
- addCustomModelItem.setFont(italicFont)
- listBox.addItem(addCustomModelItem)
- listBox.setSelectionMode(QAbstractItemView.SelectionMode.SingleSelection)
- listBox.setCurrentRow(0)
- self.listBox = listBox
- listBox.itemDoubleClicked.connect(self.ok_cb)
- topLayout.addWidget(listBox)
+ if not allowMultiSelection:
+ self.listBox.itemDoubleClicked.connect(self.ok_cb)
+ bottomLayout = QHBoxLayout()
cancelButton = widgets.cancelPushButton('Cancel')
okButton = widgets.okPushButton(' Ok ')
okButton.setShortcut(Qt.Key_Enter)
@@ -5632,53 +5641,178 @@ def __init__(
skipSegmButton = widgets.SkipPushButton('Skip segmentation')
bottomLayout.addWidget(skipSegmButton)
skipSegmButton.clicked.connect(self.skipSegm)
+ if addSelectLastSelectionButton and allowMultiSelection:
+ selectLastSelButton = widgets.reloadPushButton('Load last selection...')
+ selectLastSelButton.clicked.connect(self.selectLastSelection)
+ selectLastSelButton.setEnabled(bool(self.lastSelection))
+ bottomLayout.addWidget(selectLastSelButton)
+ if addSelectLastRecipeButton and allowMultiSelection:
+ selectLastRecipeButton = widgets.reloadPushButton('Load last recipe...')
+ selectLastRecipeButton.clicked.connect(self.selectLastRecipe)
+ selectLastRecipeButton.setEnabled(bool(self.lastSelection))
+ bottomLayout.addWidget(selectLastRecipeButton)
+ if allowMultiSelection:
+ addCustomModelButton = widgets.addPushButton('Add custom model...')
+ addCustomModelButton.clicked.connect(self.addCustomModel)
+ bottomLayout.addWidget(addCustomModelButton)
bottomLayout.addWidget(okButton)
bottomLayout.setContentsMargins(0, 10, 0, 0)
- mainLayout.addLayout(topLayout)
mainLayout.addLayout(bottomLayout)
self.setLayout(mainLayout)
- # Connect events
okButton.clicked.connect(self.ok_cb)
cancelButton.clicked.connect(self.cancel_cb)
- self.setStyleSheet(LISTWIDGET_STYLESHEET)
-
+
+ @property
+ def selectionSequence(self):
+ return self.modelSelector.selectionSequence
+
+ @property
+ def modelItemsMap(self):
+ return self.modelSelector.modelItemsMap
+
def skipSegm(self):
self.cancel = False
self.selectedModel = 'skip_segmentation'
self.close()
-
+
+ def selectLastSelection(self):
+ if not self.lastSelection:
+ return
+ self.modelSelector.setSelectionFromList(self.lastSelection)
+
+ def selectLastRecipe(self):
+ if not self.lastSelection:
+ return
+ self.selectLastSelection()
+ self.cancel = False
+ self.loadLastRecipe = True
+ self.selectedModel = self.lastSelection.copy()
+ self.close()
+
+ def _runAddCustomModelWorkflow(self):
+ modelFilePath = addCustomModelMessages(self)
+ if modelFilePath is None:
+ return None
+
+ myutils.store_custom_model_path(modelFilePath)
+ modelName = os.path.basename(os.path.dirname(modelFilePath))
+ self.modelSelector.registerCustomModel(modelName)
+ return modelName
+
+ def addCustomModel(self):
+ modelName = self._runAddCustomModelWorkflow()
+ if modelName is None:
+ return
+
+ if self.allowMultiSelection:
+ self.modelSelector.addModelSelection(modelName)
+ else:
+ item = QListWidgetItem(modelName)
+ self.listBox.addItem(item)
+ self.listBox.setCurrentItem(item)
+
+ def showInfoLabel(self):
+ if not self.info_label:
+ return
+ msg = widgets.myMessageBox(showCentered=False, wrapText=False)
+ txt = html_utils.paragraph(self.info_label)
+ msg.information(self, 'More info', txt)
+
def keyPressEvent(self, event: QKeyEvent) -> None:
if event.key() == Qt.Key_Escape:
event.ignore()
return
-
super().keyPressEvent(event)
- def ok_cb(self, event):
+ def askSelectedModelsOrder(self, selected_models):
+ dialog = QDialog(self)
+ dialog.setWindowTitle('Order selected models')
+ dialog.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint)
+
+ layout = QVBoxLayout(dialog)
+ infoTxt = html_utils.paragraph(
+ 'Drag and drop to change the order of the selected models.
'
+ 'The top model will run first.'
+ )
+ layout.addWidget(QLabel(infoTxt))
+
+ modelOrderView = widgets.ReorderableListView(
+ selected_models, parent=dialog, isSingleSelection=True
+ )
+ layout.addWidget(modelOrderView)
+
+ buttonsLayout = QHBoxLayout()
+ cancelButton = widgets.cancelPushButton('Cancel')
+ okButton = widgets.okPushButton('Ok')
+ buttonsLayout.addStretch(1)
+ buttonsLayout.addWidget(cancelButton)
+ buttonsLayout.addSpacing(20)
+ buttonsLayout.addWidget(okButton)
+ layout.addLayout(buttonsLayout)
+
+ cancelButton.clicked.connect(dialog.reject)
+ okButton.clicked.connect(dialog.accept)
+
+ if dialog.exec_() != QDialog.Accepted:
+ return None
+
+ return modelOrderView.items()
+
+ def ok_cb(self, event=None):
self.clickedButton = self.sender()
- self.cancel = False
- item = self.listBox.currentItem()
+
+ if self.allowMultiSelection:
+ if not self.selectionSequence:
+ return
+
+ selected_models = list(self.selectionSequence)
+ if len(selected_models) > 1:
+ ordered_models = self.askSelectedModelsOrder(selected_models)
+ if ordered_models is None:
+ return
+ selected_models = ordered_models
+
+ self.selectedModel = selected_models
+ self.cancel = False
+ self.close()
+ return
+
+ selected_items = self.listBox.selectedItems()
+ if not selected_items:
+ return
+
+ selected_models = [item.text() for item in selected_items]
+ if len(selected_models) > 1:
+ ordered_models = self.askSelectedModelsOrder(selected_models)
+ if ordered_models is None:
+ return
+ self.selectedModel = ordered_models
+ self.cancel = False
+ self.close()
+ return
+
+ item = selected_items[0]
model = item.text()
if model == 'Add custom model...':
- modelFilePath = addCustomModelMessages(self)
- if modelFilePath is None:
+ modelName = self._runAddCustomModelWorkflow()
+ if modelName is None:
return
- myutils.store_custom_model_path(modelFilePath)
- modelName = os.path.basename(os.path.dirname(modelFilePath))
item = QListWidgetItem(modelName)
self.listBox.addItem(item)
self.listBox.setCurrentItem(item)
elif model == 'Automatic thresholding':
- self.selectedModel = 'thresholding'
+ self.selectedModel = model
+ self.cancel = False
self.close()
else:
self.selectedModel = model
+ self.cancel = False
self.close()
- def cancel_cb(self, event):
+ def cancel_cb(self, event=None):
self.cancel = True
self.selectedModel = None
self.close()
@@ -5725,7 +5859,7 @@ def __init__(self, text, parent=None):
mainLayout.addLayout(buttonsLayout)
self.setLayout(mainLayout)
- self.setFont(font)
+ self.setFont(fonts.font)
class startStopFramesDialog(QBaseDialog):
def __init__(
@@ -5760,7 +5894,7 @@ def __init__(
okButton.clicked.connect(self.ok_cb)
cancelButton.clicked.connect(self.close)
- self.setFont(font)
+ self.setFont(fonts.font)
def ok_cb(self):
if self.selectFramesGroupbox.warningLabel.text():
@@ -6186,7 +6320,8 @@ def __init__(
self.addAdditionalValues(additionalValues)
self.setLayout(mainLayout)
- self.setFont(font)
+ if font is not None:
+ self.setFont(font)
# self.setModal(True)
def showWhySizeTisGrayed(self):
@@ -6723,9 +6858,7 @@ def __init__(self, mainWindow):
seeHereLabel.setTextFormat(Qt.RichText)
seeHereLabel.setTextInteractionFlags(Qt.TextBrowserInteraction)
seeHereLabel.setOpenExternalLinks(True)
- font = QFont()
- font.setPixelSize(12)
- seeHereLabel.setFont(font)
+ seeHereLabel.setFont(fonts.font)
seeHereLabel.setStyleSheet("padding:12px 0px 0px 0px;")
paramsLayout.addWidget(seeHereLabel, row, 0, 1, 2)
@@ -7112,7 +7245,7 @@ def __init__(
layout.addLayout(buttonsLayout, 2, 1)
self.setLayout(layout)
- self.setFont(font)
+ self.setFont(fonts.font)
def copyErrorMessage(self):
cb = QApplication.clipboard()
@@ -7444,7 +7577,7 @@ def setPosData(self):
# self.img.setCurrentPosIndex(self.pos_i)
# self.img.minMaxValuesMapper = self.mainWin.img1.minMaxValuesMapper
self.origLab = self.posData.lab.copy()
- self.origRp = skimage.measure.regionprops(self.origLab)
+ self.origRp = skimage.measure.regionprops(self.origLab) # why seperate rp here?
self.origObjs = {obj.label:obj for obj in self.origRp}
def valueChanged(self, value):
@@ -7459,7 +7592,6 @@ def apply(self, origLab=None):
if ccaAnnotRemoved:
self.mainWin.updateAllImages()
-
if origLab is None:
origLab = self.origLab.copy()
@@ -8030,7 +8162,7 @@ def addAlphaScrollbar(self, channelName, imageItem, alphaScrollBar=None):
if alphaScrollBar is None:
alphaScrollBar = QScrollBar(Qt.Horizontal)
label = QLabel(f'Alpha {channelName}')
- label.setFont(font)
+ label.setFont(fonts.font)
label.hide()
alphaScrollBar.imageItem = imageItem
alphaScrollBar.label = label
@@ -8585,7 +8717,7 @@ def __init__(self, expPaths: dict, infoPaths: dict=None, parent=None):
QAbstractItemView.SelectionMode.ExtendedSelection
)
self.treeWidget.setHeaderHidden(True)
- self.treeWidget.setFont(font)
+ self.treeWidget.setFont(fonts.font)
for exp_path, positions in expPaths.items():
pathLevels = exp_path.split(os.sep)
posFoldersInfo = None
@@ -9522,7 +9654,7 @@ def __init__(
entryWidget.setText(defaultTxt)
if not self.allowText:
entryWidget.textChanged[str].connect(self.onTextChanged)
- entryWidget.setFont(font)
+ entryWidget.setFont(fonts.font)
entryWidget.setAlignment(Qt.AlignCenter)
self.entryWidget = entryWidget
@@ -9530,7 +9662,7 @@ def __init__(
if allowedValues is not None:
notValidLabel = QLabel()
notValidLabel.setStyleSheet('color: red')
- notValidLabel.setFont(font)
+ notValidLabel.setFont(fonts.font)
notValidLabel.setAlignment(Qt.AlignCenter)
self.notValidLabel = notValidLabel
@@ -10056,7 +10188,7 @@ def __init__(
listBox.addItems(items)
listBox.setSelectionMode(QAbstractItemView.SelectionMode.ExtendedSelection)
listBox.setCurrentRow(0)
- listBox.setFont(font)
+ listBox.setFont(fonts.font)
topLayout.addWidget(listBox)
listBox.hide()
self.ListBox = listBox
@@ -10078,7 +10210,7 @@ def __init__(
if showInFileManagerPath is not None:
showInFileManagerButton.clicked.connect(self.showInFileManager)
- self.setFont(font)
+ self.setFont(fonts.font)
def setSelectedItems(self, selectedItemsText):
if self.multiPosButton.isChecked():
@@ -10976,9 +11108,7 @@ def __init__(self, filename, SizeZ, filenamesWithInfo, parent=None):
self.setLayout(mainLayout)
- font = QFont()
- font.setPixelSize(12)
- self.setFont(font)
+ self.setFont(fonts.font)
# self.setModal(True)
@@ -11839,7 +11969,7 @@ def __init__(
printl(traceback.format_exc())
self.setLayout(mainLayout)
- self.setFont(font)
+ self.setFont(fonts.font)
# self.setModal(True)
def warningNoSegmRecipes(self):
@@ -13046,7 +13176,7 @@ def __init__(
metricsTreeWidget = QTreeWidget()
metricsTreeWidget.setHeaderHidden(True)
- metricsTreeWidget.setFont(font)
+ metricsTreeWidget.setFont(fonts.font)
self.metricsTreeWidget = metricsTreeWidget
for chName in allChNames:
@@ -13237,7 +13367,7 @@ def __init__(
testButton.clicked.connect(self.test_cb)
self.setLayout(mainLayout)
- self.setFont(font)
+ self.setFont(fonts.font)
self.setStyleSheet(TREEWIDGET_STYLESHEET)
@@ -13498,7 +13628,7 @@ def __init__(self, posDatas, parent=None):
_spinBox = QSpinBox()
_spinBox.setMaximum(214748364)
_spinBox.setAlignment(Qt.AlignCenter)
- _spinBox.setFont(font)
+ _spinBox.setFont(fonts.font)
if posData.acdc_df is not None:
_val = posData.acdc_df.index.get_level_values(0).max()+1
else:
@@ -13617,7 +13747,7 @@ def __init__(self, acdcDfs, allChNames, parent=None, debug=False):
for i, (acdc_df_endname, acdc_df) in enumerate(acdcDfs.items()):
metricsTreeWidget = QTreeWidget()
metricsTreeWidget.setHeaderHidden(True)
- metricsTreeWidget.setFont(font)
+ metricsTreeWidget.setFont(fonts.font)
classified_metrics = measurements.classify_acdc_df_colnames(
acdc_df, allChNames
@@ -13805,7 +13935,7 @@ def __init__(self, acdcDfs, allChNames, parent=None, debug=False):
# self.newColNameLineEdit.editingFinished.connect(self.equationChanged)
self.setLayout(mainLayout)
- self.setFont(font)
+ self.setFont(fonts.font)
self.setStyleSheet(TREEWIDGET_STYLESHEET)
@@ -13993,7 +14123,7 @@ def __init__(
row += 1
self.equationsList = widgets.TreeWidget()
- self.equationsList.setFont(font)
+ self.equationsList.setFont(fonts.font)
self.equationsList.setHeaderLabels(['Metric', 'Expression'])
self.equationsList.setSelectionMode(
QAbstractItemView.SelectionMode.ExtendedSelection)
@@ -14302,7 +14432,7 @@ def __init__(
mainLayout.addSpacing(20)
mainLayout.addLayout(buttonsLayout)
- self.setFont(font)
+ self.setFont(fonts.font)
self.setLayout(mainLayout)
def checkDuplicateShortcuts(self, text):
@@ -14429,7 +14559,7 @@ def __init__(self, posData, parent=None):
self.setLayout(mainLayout)
- self.setFont(font)
+ self.setFont(fonts.font)
def ok_cb(self):
self.cancel = False
@@ -14645,7 +14775,7 @@ def __init__(
self.setLayout(self._layout)
- # self.setFont(font)
+ # self.setFont(fonts.font)
self.addButton.clicked.connect(self.addFeatureField)
@@ -14928,7 +15058,7 @@ def __init__(
mainLayout.addStretch()
self.setLayout(mainLayout)
- self.setFont(font)
+ self.setFont(fonts.font)
self.unitCombobox.currentTextChanged.connect(self.updateLengthUnit)
self.colorButton.clicked.disconnect()
@@ -15047,7 +15177,7 @@ def __init__(
self.setLayout(mainLayout)
- self.setFont(font)
+ self.setFont(fonts.font)
def _warnNonUniqueCategories(self, category_1, category_2):
txt = html_utils.paragraph(f"""
@@ -15108,7 +15238,7 @@ def __init__(
metricsTreeWidget = QTreeWidget()
metricsTreeWidget.setHeaderHidden(True)
- metricsTreeWidget.setFont(font)
+ metricsTreeWidget.setFont(fonts.font)
self.metricsTreeWidget = metricsTreeWidget
for groupName, features in features_groups.items():
@@ -15153,7 +15283,7 @@ def __init__(
metricsTreeWidget.itemDoubleClicked.connect(self.addFeatureName)
self.setLayout(mainLayout)
- self.setFont(font)
+ self.setFont(fonts.font)
self.setStyleSheet(TREEWIDGET_STYLESHEET)
@@ -15439,7 +15569,7 @@ def __init__(self, parent=None, title='Input'):
self.buttonsLayout = buttonsLayout
- self.setFont(font)
+ self.setFont(fonts.font)
self.setLayout(self.mainLayout)
def askText(self, prompt, infoText='', allowEmpty=False):
@@ -15909,7 +16039,7 @@ def __init__(self, parent=None, **properties):
mainLayout.addStretch()
self.setLayout(mainLayout)
- self.setFont(font)
+ self.setFont(fonts.font)
self.colorButton.clicked.disconnect()
self.colorButton.clicked.connect(self.selectColor)
@@ -19313,7 +19443,7 @@ def __init__(
self.setAcceptDrops(True)
- self.setFont(font)
+ self.setFont(fonts.font)
def dragEnterEvent(self, event):
event.acceptProposedAction()
@@ -19353,12 +19483,59 @@ def expFolderToPosFoldernamesMapper(self):
return expPathsPosFoldernamesMapper
def ok_cb(self):
- self.cancel = False
+ #verify all selected folders have Images folder:
+ faultyFolders = []
+ for path, selected_pos in self.expFolderToPosFoldernamesMapper().items():
+ if selected_pos == ['']:
+ images_path = myutils.get_images_folderpath(path)
+ if images_path is None or not os.path.exists(images_path):
+ faultyFolders.append(path)
+
+ else:
+ for pos in selected_pos:
+ pos_path = os.path.join(path, pos)
+ images_path = myutils.get_images_folderpath(pos_path)
+ if images_path is None or not os.path.exists(images_path):
+ faultyFolders.append(pos_path)
+
+ if faultyFolders:
+ self.warnNoAllValid(faultyFolders)
+ return
+
self.paths = self.pathsList()
self.selectedExpFolderToPosFoldernamesMapper = (
self.expFolderToPosFoldernamesMapper()
)
+ if not self.selectedExpFolderToPosFoldernamesMapper:
+ self.warnEmptySelection()
+ return
+ self.cancel = False
+
self.close()
+
+ def warnNoAllValid(self, faultyFolders=None):
+ msg = widgets.myMessageBox(wrapText=False)
+ txt = html_utils.paragraph(f"""
+ Some of the selected folders (see below) do not contain an Images folder.
+ Please, make sure to select Position folders, the Images folder inside Position folders, or any folder containing Position folders as sub-directories.
+ Thank you for your patience!
+ Selected folders:
+
+ {''.join(f'- {folder}
' for folder in faultyFolders)}
+
+ """)
+ msg.warning(
+ self, 'Some folders are not valid', txt
+ )
+
+ def warnEmptySelection(self):
+ msg = widgets.myMessageBox(wrapText=False)
+ txt = html_utils.paragraph("""
+ No folder was selected.
+ """)
+ msg.warning(
+ self, 'No folder selected', txt
+ )
def warnNoValidPathsFound(self, selected_path):
msg = widgets.myMessageBox(wrapText=False)
@@ -19427,11 +19604,11 @@ def addFolderPath(self, selected_path):
myutils.addToRecentPaths(selected_path)
folder_type = myutils.determine_folder_type(selected_path)
- is_pos_folder, is_images_folder, folder_path = folder_type
+ is_pos_folder, is_images_folder, folder_path = folder_type
if is_pos_folder:
paths = [selected_path]
elif is_images_folder:
- paths = [os.path.dirname(selected_path)]
+ paths = [os.path.dirname(selected_path) if selected_path.endswith('Images') else selected_path]
elif self.scanTree:
print(f'Scanning selected folder "{selected_path}"...')
exp_paths = path.get_posfolderpaths_walk(selected_path)
diff --git a/cellacdc/core.py b/cellacdc/core.py
index 80ae1df0e..d34733dc2 100755
--- a/cellacdc/core.py
+++ b/cellacdc/core.py
@@ -50,6 +50,7 @@
from . import measurements
from . import favourite_func_metrics_csv_path
from . import default_index_cols
+from . import regionprops
from ._types import (
ChannelsDict
@@ -756,6 +757,19 @@ def get_obj_contours(
only_longest_contour=True,
local=False,
):
+ if obj is not None and obj_image is None and not local:
+ if all_external and not all:
+ obj_image = np.ascontiguousarray(obj.image, dtype=np.uint8)
+ contours, _ = cv2.findContours(
+ obj_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
+ )
+ min_y, min_x, _, _ = obj.bbox
+ return [np.squeeze(cont, axis=1)+[min_x, min_y] for cont in contours]
+ if all and hasattr(obj, 'contour_all'):
+ return obj.contour_all
+ if only_longest_contour and hasattr(obj, 'contour'):
+ return obj.contour
+
if all:
retrieveMode = cv2.RETR_CCOMP
else:
@@ -1273,8 +1287,12 @@ def cca_df_to_acdc_df(cca_df, rp, acdc_df=None):
IDs.append(obj.label)
is_cell_dead_li.append(0)
is_cell_excluded_li.append(0)
- xx_centroid.append(int(obj.centroid[1]))
- yy_centroid.append(int(obj.centroid[0]))
+ if isinstance(rp, regionprops.acdcRegionprops):
+ centroid = rp.get_centroid(obj.label, exact=True)
+ else:
+ centroid = obj.centroid
+ xx_centroid.append(int(centroid[1]))
+ yy_centroid.append(int(centroid[0]))
acdc_df = pd.DataFrame({
'Cell_ID': IDs,
'is_cell_dead': is_cell_dead_li,
@@ -3036,16 +3054,25 @@ def split_connected_components(lab, rp=None, max_ID=None):
return split_occured
def split_along_convexity_defects(
- ID, lab, max_ID, max_i=1, eps_percent=0.01
+ ID, lab, max_ID, max_i=1, eps_percent=0.01, rp=None
):
- lab_ID_bool = lab == ID
+ if rp is not None:
+ obj = rp.get_obj_from_ID(ID)
+ lab_ID_bool = np.zeros_like(lab[obj.slice], dtype=bool)
+ lab_ID_bool[obj.image] = True
+ else:
+ lab_ID_bool = lab == ID
# First try separating by labelling
lab_ID = lab_ID_bool.astype(int)
rp_ID = skimage.measure.regionprops(lab_ID)
split_occured = split_connected_components(lab_ID, rp=rp_ID, max_ID=max_ID)
if split_occured:
success = True
- lab[lab_ID_bool] = lab_ID[lab_ID_bool]
+ if rp is not None:
+ lab[obj.slice][obj.image] = lab_ID[obj.image]
+ else:
+ lab[lab_ID_bool] = lab_ID[lab_ID_bool]
+
rp_ID = skimage.measure.regionprops(lab_ID)
separateIDs = [obj.label for obj in rp_ID]
return lab, success, separateIDs
@@ -3093,7 +3120,10 @@ def split_along_convexity_defects(
sep_bud_label = temp_sep_bud_lab
sep_bud_label_mask = sep_bud_label != 0
# plt.imshow_tk(sep_bud_label, dots_coords=np.asarray(defects_points))
- lab[sep_bud_label_mask] = sep_bud_label[sep_bud_label_mask]
+ if rp is not None:
+ lab[obj.slice][sep_bud_label_mask] = sep_bud_label[sep_bud_label_mask]
+ else:
+ lab[sep_bud_label_mask] = sep_bud_label[sep_bud_label_mask]
max_i += 1
success = True
return lab, success, splittedIDs
@@ -3194,53 +3224,119 @@ def insert_missing_objects(
return segm_dst
-def process_lab(task):
- i, lab = task
- # Assuming this function processes each lab independently
- data_dict = {}
- rp = skimage.measure.regionprops(lab)
- IDs = [obj.label for obj in rp]
- data_dict['IDs'] = IDs
- data_dict['regionprops'] = rp
- data_dict['IDs_idxs'] = {ID: idx for idx, ID in enumerate(IDs)}
+### out of date
+# def process_lab(task):
+# i, lab = task
+# # Assuming this function processes each lab independently
+# data_dict = {}
+# rp = skimage.measure.regionprops(lab)
+# IDs = [obj.label for obj in rp]
+# data_dict['IDs'] = IDs
+# data_dict['regionprops'] = rp
+# data_dict['IDs_idxs'] = {ID: idx for idx, ID in enumerate(IDs)}
+
+# return i, data_dict, IDs # Return index, data_dict, and IDs
+
+# def parallel_count_objects(posData, logger_func):
+# benchmark = True
+# #futile attempt to use multiprocessing to speed things up
+# logger_func('Counting total number of segmented objects...')
+
+# allIDs = set()
+# seg_data = posData.segm_data
+
+# # Initialize empty data dictionary to avoid recalculating each time
+# tasks = [(i, lab) for i, lab in enumerate(seg_data)]
+
+# if benchmark:
+# t0 = time.perf_counter()
+# # Process in batches to optimize memory usage and control parallelism
+# with ThreadPoolExecutor() as executor:
+# futures = [executor.submit(process_lab, task) for task in tasks]
+
+# # Process results as they are completed
+# for future in tqdm(as_completed(futures), total=len(futures), ncols=100):
+# i, data_dict, IDs = future.result()
+# posData.allData_li[i] = myutils.get_empty_stored_data_dict() # or directly assign if it's mutable
+# posData.allData_li[i]['IDs'] = data_dict['IDs']
+# posData.allData_li[i]['regionprops'] = data_dict['regionprops']
+# posData.allData_li[i]['IDs_idxs'] = data_dict['IDs_idxs']
+# allIDs.update(IDs)
- return i, data_dict, IDs # Return index, data_dict, and IDs
+# if benchmark:
+# t1 = time.perf_counter()
+# logger_func(f'Counting objects took {(t1 - t0)*1000:.2f} ms')
-def parallel_count_objects(posData, logger_func):
- benchmark = True
- #futile attempt to use multiprocessing to speed things up
- logger_func('Counting total number of segmented objects...')
+# return allIDs, posData
+
+def check_file_time_proximity(file1, file2, max_seconds=300, logger_func=print):
+ if not os.path.isfile(file1):
+ return False
- allIDs = set()
- seg_data = posData.segm_data
+ if not os.path.isfile(file2):
+ return False
+
+ mtime1 = os.path.getmtime(file1)
+ mtime2 = os.path.getmtime(file2)
- # Initialize empty data dictionary to avoid recalculating each time
- tasks = [(i, lab) for i, lab in enumerate(seg_data)]
+ if abs(mtime1 - mtime2) <= max_seconds:
+ return True
+ else:
+ logger_func(f'Warning: The files "{file1}" and "{file2}" were not saved within {max_seconds} seconds of each other.')
+ return False
+
+def verify_acdc_df_segm(posData: 'load.loadData', logger_func=print):
+ if posData.segmMetadata is None:
+ return None
+ segm_info = posData.segmMetadata[os.path.basename(posData.segm_npz_path)]
+ imgs_folder = posData.images_path
+ csv_name = segm_info['acdc_df_segm'] if 'acdc_df_segm' in segm_info else None
+ if csv_name is None:
+ return None
+ csv_filepath = os.path.join(imgs_folder, csv_name)
+
+ # verify that that both files exist and are within the allowed time proximity
+ success = check_file_time_proximity(
+ posData.segm_npz_path, csv_filepath, max_seconds=120, logger_func=logger_func
+ )
+ if not success:
+ return None
+
+ return csv_filepath
- if benchmark:
- t0 = time.perf_counter()
- # Process in batches to optimize memory usage and control parallelism
- with ThreadPoolExecutor() as executor:
- futures = [executor.submit(process_lab, task) for task in tasks]
+def verify_add_data_segm_proximity(posData: 'load.loadData', logger_func=print):
+ segm_path = posData.segm_npz_path
+ segm_filename = os.path.basename(segm_path).replace('.npz', '')
+ add_data_folder = os.path.join(posData.images_path, segm_filename)
+
+ centroids_path = os.path.join(add_data_folder, 'centroids.pkl')
+ # IDs_path = os.path.join(add_data_folder, 'IDs.pkl')
+ centroids_IDs_exact_path = os.path.join(add_data_folder, 'centroids_IDs_exact.pkl')
+ # ID_to_idx_path = os.path.join(add_data_folder, 'ID_to_idx.pkl')
+
+ ok = [True] * 2
+ for idx, file in enumerate([centroids_path, centroids_IDs_exact_path]):
+ success = check_file_time_proximity(
+ segm_path, file, max_seconds=120, logger_func=logger_func
+ )
+ if not success:
+ ok[idx] = False
+
+ return {
+ 'centroids': centroids_path if ok[0] else None,
+ # 'IDs': IDs_path if ok[1] else None,
+ 'centroids_IDs_exact': centroids_IDs_exact_path if ok[1] else None,
+ # 'ID_to_idx': ID_to_idx_path if ok[3] else None,
+ }
- # Process results as they are completed
- for future in tqdm(as_completed(futures), total=len(futures), ncols=100):
- i, data_dict, IDs = future.result()
- posData.allData_li[i] = myutils.get_empty_stored_data_dict() # or directly assign if it's mutable
- posData.allData_li[i]['IDs'] = data_dict['IDs']
- posData.allData_li[i]['regionprops'] = data_dict['regionprops']
- posData.allData_li[i]['IDs_idxs'] = data_dict['IDs_idxs']
- allIDs.update(IDs)
- if benchmark:
- t1 = time.perf_counter()
- logger_func(f'Counting objects took {(t1 - t0)*1000:.2f} ms')
-
- return allIDs, posData
-
-def count_objects(posData, logger_func):
- benchmark = False
-
+# WARNING: this function has been attempted to be optimized by
+# parallelization, loading data from last session
+# The main bottleneck seams to be the rp creation (not even for example getting the IDs or centorids)
+# Total time spend optimising here
+# >5 hrs
+# please update this if you try to optimize again
+def count_objects_and_init_rps(posData: 'load.loadData', logger_func=print):
allIDs = set()
segm_data = posData.segm_data
@@ -3250,23 +3346,14 @@ def count_objects(posData, logger_func):
logger_func('Counting total number of segmented objects...')
pbar = tqdm(total=len(segm_data), ncols=100)
- if benchmark:
- t0 = time.perf_counter()
for i, lab in enumerate(segm_data):
posData.allData_li[i] = myutils.get_empty_stored_data_dict()
- rp = skimage.measure.regionprops(lab)
- IDs = [obj.label for obj in rp]
- posData.allData_li[i]['IDs'] = IDs
+ rp = regionprops.acdcRegionprops(lab)
+ IDs = rp.IDs_set
posData.allData_li[i]['regionprops'] = rp
- posData.allData_li[i]['IDs_idxs'] = { # IDs_idxs[obj.label] = idx
- ID: idx for idx, ID in enumerate(IDs)
- }
allIDs.update(IDs)
pbar.update()
pbar.close()
- if benchmark:
- t1 = time.perf_counter()
- logger_func(f'Counting objects took {(t1 - t0)*1000:.2f} ms')
return allIDs, posData
def fix_sparse_directML(verbose=True):
diff --git a/cellacdc/dataPrep.py b/cellacdc/dataPrep.py
index 10b558445..9856e7f4b 100755
--- a/cellacdc/dataPrep.py
+++ b/cellacdc/dataPrep.py
@@ -44,6 +44,7 @@
from . import urls
from . import io
from .help import about
+from . import fonts
if os.name == 'nt':
try:
@@ -373,7 +374,7 @@ def gui_createToolBars(self):
navigateToolbar.addAction(self.interpAction)
self.ROIshapeComboBox = QComboBox()
- self.ROIshapeComboBox.setFont(apps.font)
+ self.ROIshapeComboBox.setFont(fonts.font)
self.ROIshapeComboBox.SizeAdjustPolicy(QComboBox.SizeAdjustPolicy.AdjustToContents)
self.ROIshapeComboBox.addItems([' 256x256 '])
ROIshapeLabel = QLabel(html_utils.paragraph(
diff --git a/cellacdc/debugutils.py b/cellacdc/debugutils.py
index b55d0eef1..40b94f717 100644
--- a/cellacdc/debugutils.py
+++ b/cellacdc/debugutils.py
@@ -1,9 +1,101 @@
import inspect, os, datetime, sys, traceback
+import atexit
+import linecache
+from collections import defaultdict
from . import cellacdc_path, myutils
import gc
import psutil
+import time
+import functools
+
+_LINE_BENCHMARK_TRACE_LIMIT = 10000
+
+_LINE_BENCHMARK_STATS = defaultdict(
+ lambda: {
+ 'count': 0,
+ 'traced_count': 0,
+ 'untracked_count': 0,
+ 'total_time': 0.0,
+ 'min_time': float('inf'),
+ 'max_time': 0.0,
+ 'filename': None,
+ 'line_stats': defaultdict(
+ lambda: {
+ 'count': 0,
+ 'total_time': 0.0,
+ 'min_time': float('inf'),
+ 'max_time': 0.0,
+ }
+ ),
+ }
+)
+
+def _get_benchmark_line_snippet(filename, lineno, max_chars=30):
+ if lineno == 'return':
+ return ''
+ if not filename:
+ return ''
+
+ line = linecache.getline(filename, lineno).strip()
+ if not line:
+ return ''
+
+ if len(line) <= max_chars:
+ # fill up to max_chars for better alignment
+ line = line.ljust(max_chars)
+ return line
+ return f'{line[:max_chars-3]}...'
+
+def _print_line_benchmark_session_stats():
+ if not _LINE_BENCHMARK_STATS:
+ return
+
+ print('\nLine benchmark session summary:')
+ for func_name, stats in sorted(_LINE_BENCHMARK_STATS.items()):
+ total_count = stats['count']
+ traced_count = stats['traced_count']
+ untracked_count = stats['untracked_count']
+ if total_count == 0:
+ continue
+
+ if traced_count:
+ mean_time = stats['total_time'] / traced_count
+ print(
+ f'{func_name}: n={total_count} | '
+ f'traced={traced_count} | '
+ f'untracked={untracked_count} | '
+ f'mean={mean_time*1000:.3f} ms | '
+ f'min={stats["min_time"]*1000:.3f} ms | '
+ f'max={stats["max_time"]*1000:.3f} ms | '
+ f'total={stats["total_time"]*1000:.3f} ms'
+ )
+ else:
+ print(
+ f'{func_name}: n={total_count} | '
+ f'traced=0 | '
+ f'untracked={untracked_count}'
+ )
+
+ line_stats = stats['line_stats']
+ top_lines = sorted(
+ line_stats.items(),
+ key=lambda item: item[1]['total_time'],
+ reverse=True
+ )[:10]
+ filename = stats['filename']
+ for (start_line, end_line), line_stat in top_lines:
+ line_mean = line_stat['total_time'] / line_stat['count']
+ line_snippet = _get_benchmark_line_snippet(filename, start_line)
+ print(
+ f' {line_snippet:<30} {start_line} -> {end_line}: '
+ f'n={line_stat["count"]} | '
+ f'mean={line_mean*1000:.3f} ms | '
+ f'total={line_stat["total_time"]*1000:.3f} ms'
+ )
+
+atexit.register(_print_line_benchmark_session_stats)
def showRefGraph(object_str:str, debug:bool=True):
"""Save a reference graph of the given object type.
@@ -206,3 +298,140 @@ def print_largest_classes(package_prefix="cellacdc", top_n=10, max_instances=100
# Example usage:
# print_largest_classes("cellacdc", top_n=10)
+
+# Return a benchmark checkpoint with caller line information.
+def return_timer_and_line(benchmarking=True):
+ if not benchmarking:
+ return None
+ timestamp = time.perf_counter()
+ line = inspect.currentframe().f_back.f_lineno # is super fast!
+ return (timestamp, line)
+
+def print_benchmarks(timers, benchmarking=True):
+ if not benchmarking:
+ return
+ checkpoints = [timer for timer in timers if timer is not None]
+ if len(checkpoints) < 2:
+ return
+
+ print("Benchmarks:")
+ for (start_time, start_line), (end_time, end_line) in zip(
+ checkpoints, checkpoints[1:]
+ ):
+ duration = end_time - start_time
+ print(
+ f"Line {start_line} -> {end_line}: "
+ f"{duration:.6f} seconds"
+ )
+
+ total_duration = checkpoints[-1][0] - checkpoints[0][0]
+ print(f"Total: {total_duration:.6f} seconds")
+
+def line_benchmark(func):
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ stats_key = f'{func.__module__}.{func.__qualname__}'
+ stats = _LINE_BENCHMARK_STATS[stats_key]
+ stats['count'] += 1
+
+ if stats['traced_count'] >= _LINE_BENCHMARK_TRACE_LIMIT:
+ stats['untracked_count'] += 1
+ return func(*args, **kwargs)
+
+ target_code = func.__code__
+ filename = target_code.co_filename
+ checkpoints = []
+ last_time = None
+ last_line = None
+
+ def tracer(frame, event, arg):
+ nonlocal last_time, last_line
+
+ if frame.f_code is not target_code:
+ return tracer
+
+ now = time.perf_counter()
+
+ if event == "call":
+ last_time = now
+ last_line = frame.f_lineno
+ return tracer
+
+ if event == "line":
+ if last_time is not None and last_line is not None:
+ checkpoints.append((last_line, frame.f_lineno, now - last_time))
+ last_time = now
+ last_line = frame.f_lineno
+ return tracer
+
+ if event == "return":
+ if last_time is not None and last_line is not None:
+ checkpoints.append((last_line, "return", now - last_time))
+ return tracer
+
+ return tracer
+
+ old_trace = sys.gettrace()
+ sys.settrace(tracer)
+ try:
+ result = func(*args, **kwargs)
+ finally:
+ sys.settrace(old_trace)
+
+ total = sum(dt for _, _, dt in checkpoints)
+ stats['traced_count'] += 1
+ stats['total_time'] += total
+ stats['min_time'] = min(stats['min_time'], total)
+ stats['max_time'] = max(stats['max_time'], total)
+ stats['filename'] = filename
+
+ for start_line, end_line, dt in checkpoints:
+ line_stat = stats['line_stats'][(start_line, end_line)]
+ line_stat['count'] += 1
+ line_stat['total_time'] += dt
+ line_stat['min_time'] = min(line_stat['min_time'], dt)
+ line_stat['max_time'] = max(line_stat['max_time'], dt)
+
+ return result
+
+ return wrapper
+
+def check_unused_methods(node_name):
+ import ast
+ from pathlib import Path
+ from collections import Counter
+
+ file_path = Path("cellacdc/gui.py")
+ src = file_path.read_text(encoding="utf-8")
+ tree = ast.parse(src)
+
+ gui_cls = None
+ for node in tree.body:
+ if isinstance(node, ast.ClassDef) and node.name == node_name:
+ gui_cls = node
+ break
+
+ if gui_cls is None:
+ raise RuntimeError(f"{node_name} class not found")
+
+ methods = []
+ refs = Counter()
+
+ for node in gui_cls.body:
+ if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
+ methods.append(node.name)
+
+ for n in ast.walk(node):
+ # Count any self.method reference (calls, signal connections, passing callbacks, etc.)
+ if isinstance(n, ast.Attribute) and isinstance(n.value, ast.Name) and n.value.id == "self":
+ refs[n.attr] += 1
+
+ # Also count guiWin.method(...) direct class-qualified calls inside class
+ if isinstance(n, ast.Attribute) and isinstance(n.value, ast.Name) and n.value.id == node_name:
+ refs[n.attr] += 1
+
+ unused_inside_guiwin = [m for m in methods if refs[m] == 0]
+ print(f"Total methods: {len(methods)}")
+ print(f"Potentially unused inside {node_name}: {len(unused_inside_guiwin)}")
+ for m in sorted(unused_inside_guiwin):
+ print(m)
\ No newline at end of file
diff --git a/cellacdc/docs/source/installation.rst b/cellacdc/docs/source/installation.rst
index 901bde231..d5fc9fa2d 100644
--- a/cellacdc/docs/source/installation.rst
+++ b/cellacdc/docs/source/installation.rst
@@ -485,9 +485,26 @@ If you want to try out experimental features (and, if you have time, maybe repor
Updating Cell-ACDC installed from source
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
To update Cell-ACDC installed from source, open a terminal window, navigate to the
Cell-ACDC folder with the command ``cd Cell_ACDC`` and run ``git pull``.
-Since you installed with the ``-e`` flag, pulling with ``git`` is enough.
\ No newline at end of file
+Since you installed with the ``-e`` flag, pulling with ``git`` is enough.
+
+
+
+Compile Cython extensions
+-------------------------
+
+Some of the functions in Cell-ACDC are implemented in Cython, which allows them
+to run much faster. However, Cython functions need to be compiled before they
+can be used. We provide pre-compiled versions of the Cython extensions for
+Windows, macOS and Linux, but if you want to compile them yourself,
+you can do so by activating your ``acdc`` environment, installing
+Cython and setuptools with the command ``pip install Cython setuptools``,
+and running the following command in the terminal from the Cell_ACDC folder:
+
+.. code-block::
+
+ python precompile_functions.py build_ext --inplace --build-temp build/temp
\ No newline at end of file
diff --git a/cellacdc/docs/source/tooltips.rst b/cellacdc/docs/source/tooltips.rst
index 8a0404896..9857908af 100644
--- a/cellacdc/docs/source/tooltips.rst
+++ b/cellacdc/docs/source/tooltips.rst
@@ -447,8 +447,8 @@ Edit tools: Segmentation and tracking
* Note: right-click on a background ROI to remove it.
* HELP: Use this function if you need to set the background level specific for each object. Cell-ACDC will save the metrics `amount`, `concentration` and `corrected_mean` where the background correction will be performed by subtracting the mean of the signal in the background ROI (for each object).
* **Delete everything outside segmented areas (** |delObjsOutSegmMaskAction| **):** Select a segmentation file and delete everything outside segmented area.
-* **Hull contour (** |hullContToolButton| **"K"):** Right-click on a cell to replace it with its hull contour. Use it to fill cracks and holes.
-* **Fill holes (** |fillHolesToolButton| **"F"):** Right-click on a cell to fill holes.
+* **Hull contour (** |hullContToolButton| **"K"):** Right-click on a cell to replace it with its hull contour. Use it to fill cracks and holes. When working with 3D segmentation masks, the default behaviour is to replace only the viewed z-slice. To replace the entire object in 3D, hold "Shift" while right-clicking.
+* **Fill holes (** |fillHolesToolButton| **"F"):** Right-click on a cell to fill holes. When working with 3D segmentation masks, the default behaviour is to fill only the viewed z-slice. To fill the entire object in 3D, hold "Shift" while right-clicking.
* **Move object mask (** |moveLabelToolButton| **"P"):** Right-click drag and drop a labels to move it around.
* **Expand/Shrink object mask (** |expandLabelToolButton| **"E"):** Leave mouse cursor on the label you want to expand/shrink and press arrow up/down on the keyboard to expand/shrink the mask.
* **Edit ID (** |editIDbutton| **"N"):** Manually change ID of a cell by right-clicking on cell. When working with 3D segmentation masks, the default behaviour is to edit the ID in all z-slices. To edit the ID only on the viewed z-slice, hold "Shift" while right-clicking.
diff --git a/cellacdc/fonts.py b/cellacdc/fonts.py
new file mode 100644
index 000000000..31219da52
--- /dev/null
+++ b/cellacdc/fonts.py
@@ -0,0 +1,14 @@
+from . import GUI_INSTALLED
+
+if GUI_INSTALLED:
+ from qtpy.QtGui import QFont
+
+ font = QFont()
+ font.setPixelSize(12)
+ italicFont = QFont()
+ italicFont.setPixelSize(12)
+ italicFont.setItalic(True)
+
+else:
+ font = None
+ italicFont = None
\ No newline at end of file
diff --git a/cellacdc/gui.py b/cellacdc/gui.py
index 7247f462e..e09b86bbc 100755
--- a/cellacdc/gui.py
+++ b/cellacdc/gui.py
@@ -90,12 +90,13 @@
from . import is_mac
from .trackers.CellACDC import CellACDC_tracker
from .cca_functions import _calc_rot_vol
-from .myutils import exec_time, setupLogger, ArgSpec
+from .myutils import setupLogger, ArgSpec
from .help import welcome, about
from .trackers.CellACDC_normal_division.CellACDC_normal_division_tracker import (
normal_division_lineage_tree)#, reorg_sister_cells_for_export)
from . import debugutils
-
+from . import regionprops
+from . import exec_time
from .plot import imshow
from . import gui_utils
@@ -114,6 +115,11 @@
GREEN_HEX = _palettes.green()
+RP_OPT_NUM_CELLS_MIN = 30 # th for trying to do local updates to regionprops, rp becomes slow for high num of cells
+RP_OPT_PERC_CUTOUT_MAX = 0.3 # th for trying to do local updates to regionprops,
+ # if region which we have to update is too large too
+ # many cells are probably inside and its not worth
+ # local updating (since we actually need to call RP twice!)
custom_annot_path = os.path.join(settings_folderpath, 'custom_annotations.json')
shortcut_filepath = os.path.join(settings_folderpath, 'shortcuts.ini')
@@ -537,7 +543,7 @@ def initGlobalAttr(self):
]
self.lin_tree_df_colnames = self.lin_tree_df_int_cols + self.lin_tree_df_bool_col + self.lin_tree_col_checks
- self.SegForLostIDsSettings = {}
+ self.SegForLostIDsSettings = {}
def setWindowIcon(self, icon=None):
if icon is None:
@@ -1344,7 +1350,6 @@ def gui_createToolBars(self):
self.hullContToolButton.action = editToolBar.addWidget(self.hullContToolButton)
self.checkableButtons.append(self.hullContToolButton)
self.checkableQButtonsGroup.addButton(self.hullContToolButton)
- self.functionsNotTested3D.append(self.hullContToolButton)
self.widgetsWithShortcut['Hull contour'] = self.hullContToolButton
self.fillHolesToolButton = QToolButton(self)
@@ -1356,7 +1361,6 @@ def gui_createToolBars(self):
)
self.checkableButtons.append(self.fillHolesToolButton)
self.checkableQButtonsGroup.addButton(self.fillHolesToolButton)
- self.functionsNotTested3D.append(self.fillHolesToolButton)
self.widgetsWithShortcut['Fill holes'] = self.fillHolesToolButton
self.moveLabelToolButton = QToolButton(self)
@@ -2034,6 +2038,7 @@ def gui_createControlsToolbar(self):
brushEraserToolBar.addWidget(QLabel(' '))
self.brushAutoFillCheckbox = QCheckBox('Auto-fill holes')
+ self.brushAutoFillCheckbox.setTristate(False)
self.brushAutoFillAction = brushEraserToolBar.addWidget(
self.brushAutoFillCheckbox
)
@@ -2708,7 +2713,6 @@ def gui_createActions(self):
# Edit actions
models = myutils.get_list_of_models()
- models = [*models, 'local_seg'] # Add local_seg for SegForLostIDsAction
self.segmActions = []
self.modelNames = []
self.acdcSegment_li = []
@@ -2765,7 +2769,7 @@ def gui_createActions(self):
'Track current frame with real-time tracker...', self
)
self.repeatTrackingMenuAction.setDisabled(True)
- self.repeatTrackingMenuAction.setShortcut('Shift+T')
+ self.repeatTrackingMenuAction.setShortcut('Ctrl+T')
self.repeatTrackingVideoAction = QAction(
'Select a tracker and track multiple frames...', self
@@ -3777,8 +3781,7 @@ def gui_createQuickSettingsWidgets(self):
def showAllContoursToggled(self):
if not self.isDataLoaded:
return
-
- self.computeAllContours()
+
self.updateAllImages()
def gui_createImg1Widgets(self):
@@ -4692,7 +4695,8 @@ def _gui_createGraphicsItems(self):
posData = self.data[self.pos_i]
- allIDs, posData = core.count_objects(posData, self.logger.info)
+ allIDs, posData = core.count_objects_and_init_rps(
+ posData, self.logger.info)
self.highLowResAction.setChecked(True)
numItems = len(allIDs)
@@ -4947,6 +4951,23 @@ def gui_initImg1BottomWidgets(self):
self.zSliceSpinbox.hide()
self.SizeZlabel.hide()
+ def rpCurr2D(self, frame_i=None, slice_i=None, depth_axis=None):
+ posData = self.data[self.pos_i]
+ if frame_i is None:
+ rp = posData.rp
+ else:
+ rp = posData.allData_li[frame_i]['regionprops']
+ if not self.isSegm3D:
+ return rp
+
+ if slice_i is None:
+ slice_i = self.zSliceScrollBar.sliderPosition()
+ if depth_axis is None:
+ depth_axis = self.switchPlaneCombobox.depthAxes()
+
+ rp2D = rp.get_slice_rp(slice_i, slicing=depth_axis)
+ return rp2D
+
@exception_handler
def gui_mousePressEventImg2(self, event: QGraphicsSceneMouseEvent):
modifiers = QGuiApplication.keyboardModifiers()
@@ -4986,9 +5007,9 @@ def gui_mousePressEventImg2(self, event: QGraphicsSceneMouseEvent):
x, y = event.pos().x(), event.pos().y()
xdata, ydata = int(x), int(y)
- Y, X = self.get_2Dlab(posData.lab).shape
+ Y, X = self.get_2Dlab(posData.lab, force_z=False).shape
if xdata >= 0 and xdata < X and ydata >= 0 and ydata < Y:
- ID = self.get_2Dlab(posData.lab)[ydata, xdata]
+ ID = self.get_2Dlab(posData.lab, force_z=False)[ydata, xdata]
else:
return
@@ -5135,7 +5156,9 @@ def gui_mousePressEventImg2(self, event: QGraphicsSceneMouseEvent):
return
else:
ID = sepID_prompt.EntryID
- y, x = posData.rp[posData.IDs_idxs[ID]].centroid[-2:]
+
+ centroid = posData.rp.get_centroid(ID)
+ y, x = self.getObjCentroid(centroid)
xdata, ydata = int(x), int(y)
# Store undo state before modifying stuff
@@ -5152,7 +5175,7 @@ def gui_mousePressEventImg2(self, event: QGraphicsSceneMouseEvent):
# self.set_2Dlab(lab2D)
elif not shift:
result = core.split_along_convexity_defects(
- ID, self.get_2Dlab(posData.lab), max_ID
+ ID, self.get_2Dlab(posData.lab), max_ID, rp=posData.rp
)
lab2D, success, splittedIDs = result
self.set_2Dlab(lab2D)
@@ -5190,7 +5213,9 @@ def gui_mousePressEventImg2(self, event: QGraphicsSceneMouseEvent):
self.storeManualSeparateDrawMode(manualSep.drawMode)
# Update data (rp, etc)
- self.update_rp()
+ bbox = self.update_rp_get_bbox(use_bbox=True, specific_IDs=ID) # use old ID to get bbox
+ specific_IDs = list(splittedIDs) + [ID]
+ self.update_rp(specific_IDs=specific_IDs, preloaded_bbox=bbox)
# Repeat tracking
self.trackSubsetIDs(splittedIDs)
@@ -5233,13 +5258,24 @@ def gui_mousePressEventImg2(self, event: QGraphicsSceneMouseEvent):
if ID in posData.lab:
# Store undo state before modifying stuff
self.storeUndoRedoStates(False)
- obj_idx = posData.IDs.index(ID)
- obj = posData.rp[obj_idx]
- objMask = self.getObjImage(obj.image, obj.bbox)
- localFill = scipy.ndimage.binary_fill_holes(objMask)
- posData.lab[self.getObjSlice(obj.slice)][localFill] = ID
+ if not shift and self.isSegm3D:
+ rp2D = self.rpCurr2D()
+ obj = rp2D.get_obj_from_ID(ID)
+ else: # shift hold or 2D from the getgo
+ obj = posData.rp.get_obj_from_ID(ID)
+
+ localFill = scipy.ndimage.binary_fill_holes(obj.image)
+
+ if not shift and self.isSegm3D:
+ curr_z = self.zSliceScrollBar.sliderPosition()
+ posData.lab[curr_z][obj.slice][localFill] = ID
+ else:
+ posData.lab[obj.slice][localFill] = ID
- self.update_rp()
+ # here it is impossible that hole filling overwrites an ID which
+ # otuches border
+
+ self.update_rp(use_bbox=True, specific_IDs=ID)
self.updateAllImages()
if not self.fillHolesToolButton.findChild(QAction).isChecked():
@@ -5272,13 +5308,24 @@ def gui_mousePressEventImg2(self, event: QGraphicsSceneMouseEvent):
if ID in posData.lab:
# Store undo state before modifying stuff
self.storeUndoRedoStates(False)
- obj_idx = posData.IDs.index(ID)
- obj = posData.rp[obj_idx]
- objMask = self.getObjImage(obj.image, obj.bbox)
- localHull = skimage.morphology.convex_hull_image(objMask)
- posData.lab[self.getObjSlice(obj.slice)][localHull] = ID
+ if not shift and self.isSegm3D:
+ rp2D = self.rpCurr2D()
+ obj = rp2D.get_obj_from_ID(ID)
+ else:
+ obj = posData.rp.get_obj_from_ID(ID)
+
+ localHull = skimage.morphology.convex_hull_image(obj.image)
+ if not shift and self.isSegm3D:
+ curr_z = self.zSliceScrollBar.sliderPosition()
+ hull_lab = posData.lab[curr_z][obj.slice]
+ else:
+ hull_lab = posData.lab[obj.slice]
- self.update_rp()
+ IDs_overwritten = np.unique(hull_lab[localHull])
+ IDs_overwritten = IDs_overwritten[IDs_overwritten != 0]
+ hull_lab[localHull] = ID
+
+ self.update_rp(use_bbox=True, specific_IDs=IDs_overwritten)
self.updateAllImages()
if not self.hullContToolButton.findChild(QAction).isChecked():
@@ -5292,30 +5339,6 @@ def gui_mousePressEventImg2(self, event: QGraphicsSceneMouseEvent):
x, y = event.pos().x(), event.pos().y()
self.startMovingLabel(x, y)
- # Fill holes
- elif right_click and self.fillHolesToolButton.isChecked():
- x, y = event.pos().x(), event.pos().y()
- xdata, ydata = int(x), int(y)
- ID = self.get_2Dlab(posData.lab)[ydata, xdata]
- if ID == 0:
- nearest_ID = core.nearest_nonzero_2D(
- self.get_2Dlab(posData.lab), y, x
- )
- clickedBkgrID = apps.QLineEditDialog(
- title='Clicked on background',
- msg='You clicked on the background.\n'
- 'Enter here the ID that you want to '
- 'fill the holes of',
- parent=self, allowedValues=posData.IDs,
- defaultTxt=str(nearest_ID),
- isInteger=True
- )
- clickedBkgrID.exec_()
- if clickedBkgrID.cancel:
- return
- else:
- ID = clickedBkgrID.EntryID
-
# Merge IDs
elif right_click and self.mergeIDsButton.isChecked():
x, y = event.pos().x(), event.pos().y()
@@ -5344,9 +5367,8 @@ def gui_mousePressEventImg2(self, event: QGraphicsSceneMouseEvent):
self.storeUndoRedoStates(False)
self.firstID = ID
- obj_idx = posData.IDs_idxs[ID]
- obj = posData.rp[obj_idx]
- yc, xc = self.getObjCentroid(obj.centroid)
+ centroid = posData.rp.get_centroid(ID) # maybe use 2D centroid here?
+ yc, xc = self.getObjCentroid(centroid)
self.clickObjYc, self.clickObjXc = int(yc), int(xc)
# Edit ID
@@ -5373,8 +5395,8 @@ def gui_mousePressEventImg2(self, event: QGraphicsSceneMouseEvent):
else:
ID = editID_prompt.EntryID
- obj_idx = posData.IDs_idxs[ID]
- y, x = posData.rp[obj_idx].centroid[-2:]
+ centroid = posData.rp.get_centroid(ID, exact=True)
+ y, x = self.getObjCentroid(centroid)
xdata, ydata = int(x), int(y)
posData.disableAutoActivateViewerWindow = True
@@ -5402,7 +5424,7 @@ def gui_mousePressEventImg2(self, event: QGraphicsSceneMouseEvent):
return
if editID.assignNewID:
- self.assignNewIDfromClickedID(ID, event)
+ self.assignNewIDfromClickedID(ID, event, shift=shift)
return
if not self.doNotAskAgainExistingID:
@@ -5646,7 +5668,7 @@ def expandLabel(self, dilation=True):
ID = self.hoverLabelID
- obj = posData.rp[posData.IDs.index(ID)]
+ obj = posData.rp.get_obj_from_ID(ID)
if reinitExpandingLab:
# Store undo state before modifying stuff
@@ -5679,7 +5701,9 @@ def expandLabel(self, dilation=True):
expandedLab[self.currentLab2D>0] = 0
# Get coords of the dilated/eroded object
- expandedObj = skimage.measure.regionprops(expandedLab)[0]
+ expandedObj = regionprops.acdcRegionprops(
+ expandedLab, precache_centroids=False)[0]
+ expandedObj_bbox = expandedObj.bbox
expandedObjCoords = (expandedObj.coords[:,-2], expandedObj.coords[:,-1])
# Add the dilated/erored object
@@ -5689,7 +5713,10 @@ def expandLabel(self, dilation=True):
self.set_2Dlab(lab_2D)
self.currentLab2D = lab_2D
- self.update_rp()
+ preloaded_bbox = self.update_rp_get_bbox(custom_bbox=expandedObj_bbox)
+ self.update_rp(preloaded_bbox=preloaded_bbox, specific_IDs=ID)
+ # we dont draw over other IDs so this is rare case where its fine
+ # to just have tight bbox and specific_IDs=ID
if self.labelsGrad.showLabelsImgAction.isChecked():
self.img2.setImage(img=self.currentLab2D, autoLevels=False)
@@ -5712,7 +5739,7 @@ def startMovingLabel(self, xPos, yPos):
self.searchedIDitemLeft.setData([], [])
self.movingID = ID
self.prevMovePos = (xdata, ydata)
- movingObj = posData.rp[posData.IDs.index(ID)]
+ movingObj = posData.rp.get_obj_from_ID(ID)
self.movingObjCoords = movingObj.coords.copy()
yy, xx = movingObj.coords[:,-2], movingObj.coords[:,-1]
self.currentLab2D[yy, xx] = 0
@@ -5788,7 +5815,7 @@ def gui_mouseDragEventImg1(self, event):
return
posData = self.data[self.pos_i]
- Y, X = self.get_2Dlab(posData.lab).shape
+ Y, X = self.get_2Dlab(posData.lab, force_z=False).shape
xdata, ydata = int(x), int(y)
if not myutils.is_in_bounds(xdata, ydata, X, Y):
return
@@ -5798,7 +5825,7 @@ def gui_mouseDragEventImg1(self, event):
# Brush dragging mouse --> keep brushing
elif self.isMouseDragImg1 and self.brushButton.isChecked():
- lab_2D = self.get_2Dlab(posData.lab)
+ lab_2D = self.get_2Dlab(posData.lab, force_z=False)
# t1 = time.perf_counter()
@@ -5835,7 +5862,7 @@ def gui_mouseDragEventImg1(self, event):
# t5 = time.perf_counter()
- lab2D = self.get_2Dlab(posData.lab)
+ lab2D = self.get_2Dlab(posData.lab, force_z=False)
brushMask = np.logical_and(
lab2D[diskSlice] == posData.brushID, diskMask
)
@@ -5860,7 +5887,7 @@ def gui_mouseDragEventImg1(self, event):
# Eraser dragging mouse --> keep erasing
elif self.isMouseDragImg1 and self.eraserButton.isChecked():
posData = self.data[self.pos_i]
- lab_2D = self.get_2Dlab(posData.lab)
+ lab_2D = self.get_2Dlab(posData.lab, force_z=False)
rrPoly, ccPoly = self.getPolygonBrush((y, x), Y, X)
ymin, xmin, ymax, xmax, diskMask = self.getDiskMask(xdata, ydata)
@@ -5948,18 +5975,28 @@ def gui_mouseDragEventImg1(self, event):
self.zoomRectItem.setSize((w, h))
# @exec_time
- def fillHolesID(self, ID, sender='brush'):
+ def fillHolesID(self, ID, sender='brush', enabled=None):
posData = self.data[self.pos_i]
if sender == 'brush':
- if not self.brushAutoFillCheckbox.isChecked():
+ if enabled is None:
+ enabled = self.brushAutoFillCheckbox.isChecked()
+
+ if not enabled:
+ return False
+
+ if not self.brushButton.isChecked():
return False
- lab2D = self.get_2Dlab(posData.lab)
+ lab2D = self.get_2Dlab(posData.lab, force_z=False)
mask = lab2D == ID
filledMask = scipy.ndimage.binary_fill_holes(mask)
- lab2D[filledMask] = ID
+ newFilledMask = np.logical_and(filledMask, ~mask)
+ if not np.any(newFilledMask):
+ return False
- self.set_2Dlab(lab2D)
+ # Apply only newly filled pixels to avoid rewriting the full 3D
+ # stack when editing in projection mode.
+ self.applyBrushMask(newFilledMask, ID)
return True
return False
@@ -5984,11 +6021,9 @@ def highlightSearchedIDcheckBoxToggled(self, checked):
self.highlightedID = self.getHighlightedID()
if self.highlightedID == 0:
return
- objIdx = posData.IDs_idxs[self.highlightedID]
- obj_idx = posData.IDs_idxs.get(self.highlightedID)
- if obj_idx is None:
+ obj = posData.rp.get_obj_from_ID(self.highlightedID)
+ if obj is None:
return
- obj = posData.rp[objIdx]
self.goToZsliceSearchedID(obj)
def setHighlightID(self, doHighlight):
@@ -6008,13 +6043,12 @@ def propsWidgetIDvalueChanged(self, ID):
return
propsQGBox = self.guiTabControl.propsQGBox
- obj_idx = posData.IDs_idxs.get(ID)
- if obj_idx is None:
+ obj = posData.rp.get_obj_from_ID(ID)
+ if obj is None:
s = f'Object ID {int(ID):d} does not exist'
propsQGBox.notExistingIDLabel.setText(s)
return
- obj = posData.rp[obj_idx]
self.goToZsliceSearchedID(obj)
self.updatePropsWidget(int(ID))
@@ -6039,7 +6073,7 @@ def updatePropsWidget(self, ID, fromHover=False):
return
if posData.rp is None:
- self.update_rp()
+ self.update_rp() # IDK when can this happen?
if not posData.IDs:
# empty segmentation mask
@@ -6051,8 +6085,8 @@ def updatePropsWidget(self, ID, fromHover=False):
propsQGBox = self.guiTabControl.propsQGBox
- obj_idx = posData.IDs_idxs.get(ID)
- if obj_idx is None:
+ obj = posData.rp.get_obj_from_ID(ID)
+ if obj is None:
s = f'Object ID {int(ID):d} does not exist'
propsQGBox.notExistingIDLabel.setText(s)
return
@@ -6068,8 +6102,6 @@ def updatePropsWidget(self, ID, fromHover=False):
if doHighlight:
self.highlightSearchedID(ID)
- obj = posData.rp[obj_idx]
-
if self.isSegm3D:
if self.zProjComboBox.currentText() == 'single z-slice':
local_z = self.z_lab() - obj.bbox[0]
@@ -6355,9 +6387,8 @@ def drawTempMothBudLine(self, event, posData):
if ID == 0:
self.BudMothTempLine.setData([x1, x2], [y1, y2])
else:
- obj_idx = posData.IDs_idxs[ID]
- obj = posData.rp[obj_idx]
- y2, x2 = self.getObjCentroid(obj.centroid)
+ centroid = posData.rp.get_centroid(ID)
+ y2, x2 = self.getObjCentroid(centroid)
self.BudMothTempLine.setData([x1, x2], [y1, y2])
def drawTempMergeObjsLine(self, event, posData, modifiers):
@@ -6370,9 +6401,8 @@ def drawTempMergeObjsLine(self, event, posData, modifiers):
y1, x1 = self.clickObjYc, self.clickObjXc
ID = self.get_2Dlab(posData.lab)[ydata, xdata]
if ID != 0:
- obj_idx = posData.IDs_idxs[ID]
- obj = posData.rp[obj_idx]
- y2, x2 = self.getObjCentroid(obj.centroid)
+ centroid = posData.rp.get_centroid(ID)
+ y2, x2 = self.getObjCentroid(centroid)
if modifier and ID > 0:
self.mergeObjsTempLine.addPoint(x2, y2)
@@ -6535,7 +6565,7 @@ def gui_setCursor(self, modifiers, event):
def warnAddingPointWithExistingId(self, point_id, table_endname=''):
posData = self.data[self.pos_i]
- if not point_id in posData.IDs_idxs:
+ if not point_id in posData.IDs:
return True
msg = widgets.myMessageBox(wrapText=False)
@@ -6695,7 +6725,7 @@ def gui_mouseDragEventImg2(self, event):
if mode == 'Viewer':
return
- Y, X = self.get_2Dlab(posData.lab).shape
+ Y, X = self.get_2Dlab(posData.lab, force_z=False).shape
x, y = event.pos().x(), event.pos().y()
xdata, ydata = int(x), int(y)
if not myutils.is_in_bounds(xdata, ydata, X, Y):
@@ -6704,7 +6734,7 @@ def gui_mouseDragEventImg2(self, event):
# Eraser dragging mouse --> keep erasing
if self.isMouseDragImg2 and self.eraserButton.isChecked():
posData = self.data[self.pos_i]
- lab_2D = self.get_2Dlab(posData.lab)
+ lab_2D = self.get_2Dlab(posData.lab, force_z=False)
Y, X = lab_2D.shape
x, y = event.pos().x(), event.pos().y()
xdata, ydata = int(x), int(y)
@@ -6735,7 +6765,7 @@ def gui_mouseDragEventImg2(self, event):
# Brush paint dragging mouse --> keep painting
if self.isMouseDragImg2 and self.brushButton.isChecked():
posData = self.data[self.pos_i]
- lab_2D = self.get_2Dlab(posData.lab)
+ lab_2D = self.get_2Dlab(posData.lab, force_z=False)
Y, X = lab_2D.shape
x, y = event.pos().x(), event.pos().y()
xdata, ydata = int(x), int(y)
@@ -6775,7 +6805,7 @@ def gui_mouseReleaseEventImg2(self, event):
if mode == 'Viewer':
return
- Y, X = self.get_2Dlab(posData.lab).shape
+ Y, X = self.get_2Dlab(posData.lab, force_z=False).shape
try:
x, y = event.pos().x(), event.pos().y()
except Exception as e:
@@ -6792,7 +6822,7 @@ def gui_mouseReleaseEventImg2(self, event):
self.isMovingLabel = False
# Update data (rp, etc)
- self.update_rp()
+ self.update_rp() # IDK can I do optimization here?
# Repeat tracking
self.tracking(enforce=True, assign_unique_new_IDs=False)
@@ -6806,7 +6836,7 @@ def gui_mouseReleaseEventImg2(self, event):
elif self.mergeIDsButton.isChecked():
x, y = event.pos().x(), event.pos().y()
xdata, ydata = int(x), int(y)
- lab2D = self.get_2Dlab(posData.lab)
+ lab2D = self.get_2Dlab(posData.lab, force_z=False)
ID = lab2D[ydata, xdata]
if ID == 0:
nearest_ID = core.nearest_nonzero_2D(
@@ -6826,31 +6856,33 @@ def gui_mouseReleaseEventImg2(self, event):
return
else:
ID = mergeID_prompt.EntryID
- obj_idx = posData.IDs_idxs[ID]
- obj = posData.rp[obj_idx]
- y2, x2 = self.getObjCentroid(obj.centroid)
- self.mergeObjsTempLine.addPoint(x2, y2)
+ centroid = posData.rp.get_centroid(ID)
+ ydata, xdata = self.getObjCentroid(centroid)
+ ydata, xdata = int(ydata), int(xdata)
xx, yy = self.mergeObjsTempLine.getData()
IDs_to_merge = lab2D[yy.astype(int), xx.astype(int)]
for ID in IDs_to_merge:
if ID == 0:
continue
- posData.lab[posData.lab==ID] = self.firstID
+ obj = posData.rp.get_obj_from_ID(ID)
+
+ posData.lab[obj.slice][obj.image] = self.firstID
self.mergeObjsTempLine.setData([], [])
self.clickObjYc, self.clickObjXc = None, None
-
- # Update data (rp, etc)
- self.update_rp()
-
+
+ bbox = self.update_rp_get_bbox(specific_IDs=IDs_to_merge,use_bbox=True) # use old IDs to get bbox
+ specific_IDs = list(IDs_to_merge) + [self.firstID]
+ self.update_rp(specific_IDs=specific_IDs,preloaded_bbox=bbox) # update with new IDs
ask_back_prop = True
if posData.frame_i == 0:
ask_back_prop = False
prev_IDs = []
else:
- prev_IDs = posData.allData_li[posData.frame_i-1]['IDs']
+ prev_IDs = (
+ posData.allData_li[posData.frame_i-1]['regionprops'].IDs)
if all(ID not in prev_IDs for ID in IDs_to_merge):
ask_back_prop = False
@@ -6889,7 +6921,7 @@ def gui_mouseReleaseEventImg1(self, event):
if mode == 'Viewer':
return
- Y, X = self.get_2Dlab(posData.lab).shape
+ Y, X = self.get_2Dlab(posData.lab, force_z=False).shape
x, y = event.pos().x(), event.pos().y()
xdata, ydata = int(x), int(y)
if not myutils.is_in_bounds(xdata, ydata, X, Y):
@@ -6919,8 +6951,9 @@ def gui_mouseReleaseEventImg1(self, event):
if self.isRightClickDragImg1 and self.curvToolButton.isChecked():
self.isRightClickDragImg1 = False
try:
- self.curvToolSplineToObj(isRightClick=True)
- self.update_rp()
+ mask, returnID = self.curvToolSplineToObj(isRightClick=True)
+ if mask is not None:
+ self.update_rp() # how can I optimize this? I think not possible tbh
if self.autoIDcheckbox.isChecked():
self.trackManuallyAddedObject(posData.brushID, True)
if self.isSnapshot:
@@ -6940,13 +6973,18 @@ def gui_mouseReleaseEventImg1(self, event):
self.isMouseDragImg1 = False
self.clearTempBrushImage()
+
+ erasedIDs = [ID for ID in self.erasedIDs if ID != 0]
# Update data (rp, etc)
- self.update_rp()
+ self.update_rp(
+ use_curr_view=True,
+ specific_IDs=erasedIDs or None,
+ ) # only visible stuff can be deleted
doUpdateImages = self.checkWarnDeletedIDwithEraser()
- if doUpdateImages:
+ if not doUpdateImages:
self.updateAllImages()
# Brush button mouse release
@@ -6967,7 +7005,8 @@ def gui_mouseReleaseEventImg1(self, event):
posData.lab[self.flood_mask] = posData.brushID
# Update data (rp, etc)
- self.update_rp()
+ # only visible stuff can be added, plus doesnt draw over eixisting
+ self.update_rp(use_curr_view=True, specific_IDs=posData.brushID)
# Repeat tracking
self.trackManuallyAddedObject(posData.brushID, self.isNewID)
@@ -7048,7 +7087,7 @@ def gui_mouseReleaseEventImg1(self, event):
self.isMovingLabel = False
# Update data (rp, etc)
- self.update_rp()
+ self.update_rp(use_curr_view=True) # only visible stuff can be moved
# Repeat tracking
self.tracking(enforce=True, assign_unique_new_IDs=False)
@@ -7083,9 +7122,9 @@ def gui_mouseReleaseEventImg1(self, event):
return
else:
ID = mothID_prompt.EntryID
- obj_idx = posData.IDs.index(ID)
- y, x = posData.rp[obj_idx].centroid
- xdata, ydata = int(x), int(y)
+ centroid = posData.rp.get_centroid(ID)
+ ydata, xdata = self.getObjCentroid(centroid)
+ ydata, xdata = int(ydata), int(xdata)
if self.isSnapshot:
# Store undo state before modifying stuff
@@ -7116,11 +7155,9 @@ def gui_mouseReleaseEventImg1(self, event):
# on a mother
budID = self.get_2Dlab(posData.lab)[self.yClickBud, self.xClickBud]
new_mothID = self.get_2Dlab(posData.lab)[ydata, xdata]
- bud_obj_idx = posData.IDs.index(budID)
- new_moth_obj_idx = posData.IDs.index(new_mothID)
- rp_budID = posData.rp[bud_obj_idx]
- rp_new_mothID = posData.rp[new_moth_obj_idx]
- if rp_budID.area >= rp_new_mothID.area:
+ bug_obj = posData.rp.get_obj_from_ID(budID)
+ new_mother_obj = posData.rp.get_obj_from_ID(new_mothID)
+ if bug_obj.area >= new_mother_obj.area:
self.assignBudMothButton.setChecked(False)
msg = widgets.myMessageBox()
txt = (
@@ -7576,7 +7613,7 @@ def gui_mousePressEventImg1(self, event: QMouseEvent):
if left_click and canBrush:
x, y = event.pos().x(), event.pos().y()
xdata, ydata = int(x), int(y)
- lab_2D = self.get_2Dlab(posData.lab)
+ lab_2D = self.get_2Dlab(posData.lab, force_z=False)
Y, X = lab_2D.shape
# Store undo state before modifying stuff
@@ -7614,7 +7651,7 @@ def gui_mousePressEventImg1(self, event: QMouseEvent):
self.setImageImg2(updateLookuptable=False)
how = self.drawIDsContComboBox.currentText()
- lab2D = self.get_2Dlab(posData.lab)
+ lab2D = self.get_2Dlab(posData.lab, force_z=False)
self.globalBrushMask = np.zeros(lab2D.shape, dtype=bool)
brushMask = localLab == posData.brushID
brushMask = np.logical_and(brushMask, diskMask)
@@ -7627,7 +7664,7 @@ def gui_mousePressEventImg1(self, event: QMouseEvent):
elif left_click and canErase:
x, y = event.pos().x(), event.pos().y()
xdata, ydata = int(x), int(y)
- lab_2D = self.get_2Dlab(posData.lab)
+ lab_2D = self.get_2Dlab(posData.lab, force_z=False)
Y, X = lab_2D.shape
# Store undo state before modifying stuff
@@ -7859,7 +7896,7 @@ def gui_mousePressEventImg1(self, event: QMouseEvent):
elif right_click and copyContourON:
hoverLostID = self.ax1_lostObjScatterItem.hoverLostID
self.copyLostObjectMask(hoverLostID)
- self.update_rp()
+ self.update_rp(use_curr_view=True) # only visible
self.updateAllImages()
self.store_data()
@@ -7909,7 +7946,7 @@ def gui_mousePressEventImg1(self, event: QMouseEvent):
if closeSpline:
self.splineHoverON = False
self.curvToolSplineToObj()
- self.update_rp()
+ self.update_rp() # dont think I can optimize this
if self.autoIDcheckbox.isChecked():
self.trackManuallyAddedObject(posData.brushID, True)
if self.isSnapshot:
@@ -7986,21 +8023,28 @@ def gui_mousePressEventImg1(self, event: QMouseEvent):
posData = self.data[self.pos_i]
currentIDs = posData.IDs.copy()
if manualTrackID in currentIDs:
- tempID = max(currentIDs) + 1
- posData.lab[posData.lab == clickedID] = tempID
- posData.lab[posData.lab == manualTrackID] = clickedID
- posData.lab[posData.lab == tempID] = manualTrackID
+ clicked_obj = posData.rp.get_obj_from_ID(clickedID)
+ manual_track_obj = posData.rp.get_obj_from_ID(manualTrackID)
+ posData.lab[clicked_obj.slice][clicked_obj.image] = manualTrackID
+ posData.lab[manual_track_obj.slice][manual_track_obj.image] = (
+ clickedID)
self.manualTrackingToolbar.showWarning(
f'The ID {manualTrackID} already exists --> '
f'ID {manualTrackID} has been swapped with {clickedID}'
)
+ assignments = {clickedID: manualTrackID,
+ manualTrackID: clickedID}
else:
- posData.lab[posData.lab == clickedID] = manualTrackID
+ clicked_obj = posData.rp.get_obj_from_ID(clickedID)
+ posData.lab[clicked_obj.slice][clicked_obj.image] = manualTrackID
self.manualTrackingToolbar.showInfo(
f'ID {clickedID} changed to {manualTrackID}.'
)
+ assignments = {clickedID: manualTrackID}
- self.update_rp()
+ # only ID change, so use assignments
+ # not 3D ready yet? Otherwise I must set assignments to None
+ self.update_rp(assignments=assignments)
self.updateAllImages()
elif right_click and manualBackgroundON:
@@ -8064,9 +8108,9 @@ def gui_mousePressEventImg1(self, event: QMouseEvent):
return
else:
ID = divID_prompt.EntryID
- obj_idx = posData.IDs.index(ID)
- y, x = posData.rp[obj_idx].centroid
- xdata, ydata = int(x), int(y)
+ centroid = posData.rp.get_centroid(ID)
+ ydata, xdata = self.getObjCentroid(centroid)
+ ydata, xdata = int(ydata), int(xdata)
if not self.isSnapshot:
# Store undo state before modifying stuff
@@ -8109,8 +8153,9 @@ def gui_mousePressEventImg1(self, event: QMouseEvent):
ID = budID_prompt.EntryID
obj_idx = posData.IDs.index(ID)
- y, x = posData.rp[obj_idx].centroid
- xdata, ydata = int(x), int(y)
+ centroid = posData.rp.get_centroid(ID)
+ ydata, xdata = self.getObjCentroid(centroid)
+ ydata, xdata = int(ydata), int(xdata)
relationship = posData.cca_df.at[ID, 'relationship']
is_history_known = posData.cca_df.at[ID, 'is_history_known']
@@ -8156,9 +8201,9 @@ def gui_mousePressEventImg1(self, event: QMouseEvent):
return
else:
ID = unknownID_prompt.EntryID
- obj_idx = posData.IDs.index(ID)
- y, x = posData.rp[obj_idx].centroid
- xdata, ydata = int(x), int(y)
+ centroid = posData.rp.get_centroid(ID)
+ ydata, xdata = self.getObjCentroid(centroid)
+ ydata, xdata = int(ydata), int(xdata)
self.annotateIsHistoryKnown(ID)
if not self.setIsHistoryKnownButton.findChild(QAction).isChecked():
@@ -8185,9 +8230,9 @@ def gui_mousePressEventImg1(self, event: QMouseEvent):
return
else:
ID = clickedBkgrDialog.EntryID
- obj_idx = posData.IDs.index(ID)
- y, x = posData.rp[obj_idx].centroid
- xdata, ydata = int(x), int(y)
+ centroid = posData.rp.get_centroid(ID)
+ ydata, xdata = self.getObjCentroid(centroid)
+ ydata, xdata = int(ydata), int(xdata)
button = self.doCustomAnnotation(ID)
if button is None:
@@ -8406,103 +8451,325 @@ def gui_addCreatedAxesItems(self):
self.ax1.exportMaskImageItem = self.exportMaskImageItem
def SegForLostIDsSetSettings(self):
+ posData = self.data[self.pos_i]
+ displayed_input_label = 'Displayed image'
+
+ recipe_json_path = os.path.join(
+ settings_folderpath, 'segmentation_for_lostIDs_recipe.json'
+ )
try:
- prev_model = str(self.df_settings.at['SegForLostIDsModel', 'value'])
+ prev_models = [
+ model.strip() for model in str(
+ self.df_settings.at['SegForLostIDsModel', 'value']
+ ).split(',') if model.strip()
+ ]
except KeyError:
- prev_model = None
- win = apps.QDialogSelectModel(parent=self, customFirst=prev_model)
+ prev_models = []
+
+ has_last_recipe = bool(prev_models) and os.path.exists(recipe_json_path)
+ seg_for_lost_ids_info = (
+ 'Segmentation for lost IDs settings
'
+ 'Use this dialog to define the segmentation workflow used for '
+ 'resegmenting local neighborhood lost IDs. Other already segmented cells are filled '
+ 'with background, which makes even dimm cells seem bright after '
+ 'rescaling before resegmentation. This is especially usefull for '
+ 'cells which have varying intensities over time, like FUCCI cells.
'
+ 'How model selection works
'
+ '- You can select one model or multiple models.
'
+ '- In multi-selection mode you can include the same model multiple '
+ 'times (for example, model A, then model B, then model A again).
'
+ '- After confirming, you can reorder the selected models. The order '
+ 'is the execution order. '
+ '- You then will be asked to set model parameters in the order selected.
'
+ ' - Pay special attention to the additional "Settings for local '
+ 'segmentation" section, here you can for example select any image as input.
'
+ 'Load last selection...
'
+ 'Restores only the list of selected model names (the recipe order '
+ 'selection), then lets you continue configuring parameters.
'
+ 'Load last recipe...
'
+ 'Loads the complete saved recipe from disk, including model order and '
+ 'all model-specific settings (when available).
'
+ 'Add custom model...
'
+ 'Lets you register an additional local custom model and include it in '
+ 'the sequence.
'
+ 'Tip: if you want to run the same model twice with different '
+ 'parameters, add it twice and configure each step independently.'
+ )
+ win = apps.QDialogSelectModel(
+ parent=self,
+ allowMultiSelection=True,
+ lastSelection=prev_models,
+ addSelectLastSelectionButton=bool(prev_models),
+ addSelectLastRecipeButton=has_last_recipe,
+ custom_title='Select model(s) for segmentation of lost IDs',
+ info_label=seg_for_lost_ids_info,
+ )
win.exec_()
if win.cancel:
self.logger.info('Seg for lost IDs cancelled.')
return
- base_model_name = win.selectedModel
- if base_model_name:
- self.df_settings.at['SegForLostIDsModel', 'value'] = base_model_name
+ if getattr(win, 'loadLastRecipe', False):
+ self.logger.info('Loading last segmentation recipe for lost IDs...')
+ try:
+ with open(recipe_json_path, 'r') as f:
+ recipe_data = json.load(f)
+ model_settings = []
+ for entry in recipe_data['models']:
+ model_settings.append({
+ 'win': None,
+ 'init_kwargs_new': entry['init_kwargs_new'],
+ 'args_new': entry['args_new'],
+ 'base_model_name': entry['base_model_name'],
+ 'init_kwargs': entry.get('init_kwargs', {}),
+ 'model_kwargs': entry.get('model_kwargs', {}),
+ 'preproc_recipe': entry.get('preproc_recipe', None),
+ 'applyPostProcessing': entry.get('applyPostProcessing', False),
+ 'standardPostProcessKwargs': entry.get('standardPostProcessKwargs', {}),
+ 'customPostProcessFeatures': entry.get('customPostProcessFeatures', None),
+ 'customPostProcessGroupedFeatures': entry.get('customPostProcessGroupedFeatures', None),
+ })
+ self.SegForLostIDsSettings = {'models_settings': model_settings}
+ # Restore model names in settings
+ restored_models = [
+ 'Automatic thresholding'
+ if m['base_model_name'] == 'thresholding'
+ else m['base_model_name']
+ for m in model_settings
+ ]
+ self.df_settings.at['SegForLostIDsModel', 'value'] = (
+ ', '.join(restored_models)
+ )
+ self.df_settings.to_csv(self.settings_csv_path)
+ self.logger.info('Last segmentation recipe loaded successfully.')
+ except Exception as e:
+ self.logger.error(f'Failed to load last recipe: {e}')
+ return
+
+ selected_models = win.selectedModel
+ if isinstance(selected_models, str):
+ selected_models = [selected_models]
+
+ if not selected_models:
+ self.logger.info('Seg for lost IDs cancelled.')
+ return
+
+ if selected_models:
+ self.df_settings.at['SegForLostIDsModel', 'value'] = (
+ ', '.join(selected_models)
+ )
self.df_settings.to_csv(self.settings_csv_path)
- model_name = 'local_seg'
+ all_extra_params = [
+ 'image_channel_name',
+ 'overlap_threshold',
+ 'padding',
+ 'size_perc_diff',
+ 'distance_filler_growth',
+ 'allow_only_tracked_cells',
+ ]
+ extra_types = {
+ 'overlap_threshold': float,
+ 'padding': float,
+ 'size_perc_diff': float,
+ 'distance_filler_growth': float,
+ 'allow_only_tracked_cells': bool,
+ 'image_channel_name': str,
+ }
+ extra_defaults = {
+ 'overlap_threshold': 0.5,
+ 'padding': 0.8,
+ 'size_perc_diff': 0.3,
+ 'distance_filler_growth': 1.,
+ 'allow_only_tracked_cells': False,
+ 'image_channel_name': displayed_input_label,
+ }
+ extra_desc = {
+ 'overlap_threshold': (
+ 'Overlap threshold with other already segemented cells '
+ 'over which newly segmented cells are discarded'
+ ),
+ 'padding': (
+ 'Padding of the box used for new segmentation around the '
+ 'segmentation from the previous frame'
+ ),
+ 'size_perc_diff': (
+ 'Relative size difference acceptable compared to previous '
+ 'frames'
+ ),
+ 'distance_filler_growth': (
+ 'Cells which are already segmented are filled with random '
+ 'noise sampled from background to ensure that they do not '
+ 'get segmented again. This parameter controls the additional '
+ 'padding around the already segmented cells.'
+ ),
+ 'allow_only_tracked_cells': (
+ 'If no new cell IDs should be permitted '
+ '(based on real time tracking)'
+ ),
+ 'image_channel_name': (
+ 'Image channel used as model input. '
+ 'Select "Displayed image" to use exactly what is currently '
+ 'shown in the viewer, or select a specific fluorescence '
+ 'channel.'
+ ),
+ }
- idx = self.modelNames.index(model_name)
- acdcSegment = self.acdcSegment_li[idx]
+ model_settings = []
+ remembered_extra_args = {}
+ for model_idx, selected_model_name in enumerate(selected_models):
+ model_name = selected_model_name
+ if model_name == 'Automatic thresholding':
+ model_name = 'thresholding'
+ try:
+ if selected_model_name in self.modelNames:
+ idx = self.modelNames.index(selected_model_name)
+ acdcSegment = self.acdcSegment_li[idx]
+ if acdcSegment is None:
+ self.logger.info(f'Importing {model_name}...')
+ acdcSegment = myutils.import_segment_module(model_name)
+ self.acdcSegment_li[idx] = acdcSegment
+ else:
+ self.logger.info(f'Importing {model_name}...')
+ acdcSegment = myutils.import_segment_module(model_name)
+ except (ImportError, KeyError) as e:
+ self.logger.error(f'Error importing {model_name}: {e}')
+ return
- try:
- if acdcSegment is None or base_model_name != self.local_seg_base_model_name:
- self.logger.info(f'Importing {base_model_name}...')
- acdcSegment = myutils.import_segment_module(base_model_name)
- self.acdcSegment_li[idx] = acdcSegment
- self.local_seg_base_model_name = base_model_name
- except (IndexError, ImportError, KeyError) as e:
- self.logger.error(f'Error importing {base_model_name}: {e}')
- return
-
- extra_params = ['overlap_threshold',
- 'padding',
- 'size_perc_diff',
- 'distance_filler_growth',
- 'max_iterations',
- 'allow_only_tracked_cells']
-
- extra_types = [float, float, float, float, int, bool]
-
- extra_defaults = [0.5, 0.8, 0.3, 1., 2, False]
-
- extra_desc = ['Overlap threshold with other already segemented cells over which newly segmented cells are discarded',
- 'Padding of the box used for new segmentation around the segmentation from the previous frame',
- 'Relative size difference acceptable compared to previous frames',
- """Cells which are already segmented are filled with random noise sampled from background
- to ensure that they don't get segmented again.
- This parameter controls the additional padding around the already segmented cells.""",
- """The algorithm will try and segment the maximum amount
- of cells in the image by running the model several
- times and filling new found cells with background noise.
- How many of these iterations should be run?""",
- "If no new cell IDs should be permitted (based on real time tracking)"]
-
- extra_ArgSpec = []
- for i, param in enumerate(extra_params):
- param = ArgSpec(name=param,
- default=extra_defaults[i],
- type=extra_types[i],
- desc=extra_desc[i],
- docstring='')
-
- extra_ArgSpec.append(param)
+ extra_params = all_extra_params
- init_params, segment_params = myutils.getModelArgSpec(acdcSegment)
- segment_params = [arg for arg in segment_params if arg[0] != 'diameter']
-
- extraParamsTitle = 'Settings for local segmentation'
- win = self.initSegmModelParams(
- base_model_name, acdcSegment, init_params, segment_params,
- extraParams=extra_ArgSpec, extraParamsTitle=extraParamsTitle,
- initLastParams=True, ini_filename='segmentation_for_lostIDs.ini',
- )
+ available_fluo_channels = [
+ ch for ch in posData.chNames if ch != self.user_ch_name
+ ]
+ channel_options = [displayed_input_label, *available_fluo_channels]
- if win is None:
- self.logger.info('Segmentation for lost IDs cancelled.')
- return
+ class _SegForLostIDsInputChannelType:
+ values = channel_options
- init_kwargs_new = {}
- args_new = {}
- for key, val in win.init_kwargs.items():
- if key in extra_params:
- args_new[key] = val
- else:
- init_kwargs_new[key] = val
+ extra_types['image_channel_name'] = _SegForLostIDsInputChannelType
+
+ extra_ArgSpec = []
+ for param in extra_params:
+ param_arg = ArgSpec(
+ name=param,
+ default=extra_defaults[param],
+ type=extra_types[param],
+ desc=extra_desc[param],
+ docstring=''
+ )
+ extra_ArgSpec.append(param_arg)
+
+ init_params, segment_params = myutils.getModelArgSpec(acdcSegment)
+ segment_params = [
+ arg for arg in segment_params if arg[0] != 'diameter'
+ ]
+
+ initLastParams = True
+ if model_name == 'thresholding':
+ win_thresh = apps.QDialogAutomaticThresholding(
+ parent=self, isSegm3D=self.isSegm3D
+ )
+ win_thresh.exec_()
+ if win_thresh.cancel:
+ self.logger.info('Segmentation for lost IDs cancelled.')
+ return
+ self.model_kwargs = win_thresh.segment_kwargs
+ thresh_method = self.model_kwargs['threshold_method']
+ gauss_sigma = self.model_kwargs['gauss_sigma']
+ segment_params = myutils.insertModelArgSpec(
+ segment_params, 'threshold_method', thresh_method
+ )
+ segment_params = myutils.insertModelArgSpec(
+ segment_params, 'gauss_sigma', gauss_sigma
+ )
+ initLastParams = False
+
+ extraParamsTitle = (
+ f'Settings for local segmentation '
+ f'({model_idx + 1}/{len(selected_models)})'
+ )
+ win = self.initSegmModelParams(
+ model_name, acdcSegment, init_params, segment_params,
+ extraParams=extra_ArgSpec,
+ extraParamsTitle=extraParamsTitle,
+ initLastParams=initLastParams,
+ ini_filename='segmentation_for_lostIDs.ini',
+ )
- for key, val in win.extra_kwargs.items():
- if key in extra_params:
- args_new[key] = val
+ if win is None:
+ self.logger.info('Segmentation for lost IDs cancelled.')
+ return
+
+ init_kwargs_new = {}
+ args_new = {}
+ for key, val in win.init_kwargs.items():
+ if key in extra_params:
+ args_new[key] = val
+ else:
+ init_kwargs_new[key] = val
+
+ for key, val in win.extra_kwargs.items():
+ if key in extra_params:
+ if key == 'image_channel_name':
+ init_kwargs_new[key] = val
+ else:
+ args_new[key] = val
+
+ for key, val in remembered_extra_args.items():
+ if key == 'image_channel_name':
+ init_kwargs_new.setdefault(key, val)
+ continue
+ args_new.setdefault(key, val)
+
+ if model_idx == 0:
+ remembered_extra_args = args_new.copy()
+ remembered_extra_args['image_channel_name'] = (
+ init_kwargs_new.get('image_channel_name', displayed_input_label)
+ )
+
+ model_settings.append({
+ 'win': win,
+ 'init_kwargs_new': init_kwargs_new,
+ 'args_new': args_new,
+ 'base_model_name': model_name,
+ 'init_kwargs': dict(win.init_kwargs),
+ 'model_kwargs': dict(win.model_kwargs),
+ 'preproc_recipe': win.preproc_recipe,
+ 'applyPostProcessing': win.applyPostProcessing,
+ 'standardPostProcessKwargs': win.standardPostProcessKwargs,
+ 'customPostProcessFeatures': win.customPostProcessFeatures,
+ 'customPostProcessGroupedFeatures': win.customPostProcessGroupedFeatures,
+ })
self.SegForLostIDsSettings = {
- 'win': win,
- 'init_kwargs_new': init_kwargs_new,
- 'args_new': args_new,
- 'base_model_name': base_model_name,
+ 'models_settings': model_settings,
}
+ # Persist recipe to disk so it survives across sessions
+ try:
+ recipe_data = {
+ 'models': [
+ {
+ 'base_model_name': ms['base_model_name'],
+ 'init_kwargs_new': ms['init_kwargs_new'],
+ 'args_new': ms['args_new'],
+ 'init_kwargs': ms['init_kwargs'],
+ 'model_kwargs': ms['model_kwargs'],
+ 'preproc_recipe': ms['preproc_recipe'],
+ 'applyPostProcessing': ms['applyPostProcessing'],
+ 'standardPostProcessKwargs': ms['standardPostProcessKwargs'],
+ 'customPostProcessFeatures': ms['customPostProcessFeatures'],
+ 'customPostProcessGroupedFeatures': ms['customPostProcessGroupedFeatures'],
+ }
+ for ms in model_settings
+ ]
+ }
+ with open(recipe_json_path, 'w') as f:
+ json.dump(recipe_data, f, indent=2, default=str)
+ except Exception as e:
+ self.logger.warning(f'Could not save recipe to disk: {e}')
+
def segForLostIDsButtonClicked(self):
self.setFrameNavigationDisabled(disable=True, why='Segmentation for lost IDs')
@@ -8526,9 +8793,9 @@ def onSegForLostInit(self):
self.SegForLostIDsSetSettings()
self.SegForLostIDsWaitCond.wakeAll()
- def SegForLostIDsWorkerAskInstallModel(self, model_name):
- myutils.check_install_package(model_name)
- self.SegForLostIDsWaitCond.wakeAll()
+ # def SegForLostIDsWorkerAskInstallModel(self, model_name):
+ # myutils.check_install_package(model_name)
+ # self.SegForLostIDsWaitCond.wakeAll()
def startSegForLostIDsWorker(self):
self.SegForLostIDsMutex = QMutex()
@@ -8542,9 +8809,9 @@ def startSegForLostIDsWorker(self):
# Connect the worker's signal to the main thread's slot
self.SegForLostIDsWorker.sigAskInit.connect(self.onSegForLostInit)
- self.SegForLostIDsWorker.sigAskInstallModel.connect(
- self.SegForLostIDsWorkerAskInstallModel
- )
+ # self.SegForLostIDsWorker.sigAskInstallModel.connect(
+ # self.SegForLostIDsWorkerAskInstallModel
+ # )
self.SegForLostIDsWorker.sigshowImageDebug.connect(
self.showImageDebug
)
@@ -8553,8 +8820,13 @@ def startSegForLostIDsWorker(self):
self.SegForLostIDsWorkerAskInstallGPU
)
- self.SegForLostIDsWorker.sigStoreData.connect(self.onSigStoreDataSegForLostIDsWorker)
- self.SegForLostIDsWorker.sigUpdateRP.connect(self.onSigUpdateRPSegForLostIDsWorker)
+ self.SegForLostIDsWorker.sigStoreData.connect(
+ self.onSigStoreDataSegForLostIDsWorker)
+ self.SegForLostIDsWorker.sigUpdateRP.connect(
+ self.onSigUpdateRPSegForLostIDsWorker)
+ self.SegForLostIDsWorker.sigGetSegForLostIDsInputImg.connect(
+ self.onSigGetInputImgSegForLostIDsWorker
+ )
# self.SegForLostIDsWorker.sigGetData.connect(self.onSigGetDataSegForLostIDsWorker)
# self.SegForLostIDsWorker.sigGet2Dlab.connect(self.onSigGet2DlabSegForLostIDsWorker)
# self.SegForLostIDsWorker.sigGetTrackedLostIDs.connect(self.onSigGetTrackedSegForLostIDsWorker)
@@ -8619,6 +8891,31 @@ def onSigTrackManuallyAddedObjectSegForLostIDsWorker(self, added_IDs, isNewID, w
self.trackManuallyAddedObject(added_IDs, isNewID, wl_update=wl_update, wl_track_og_curr=wl_track_og_curr)
self.SegForLostIDsWaitCond.wakeAll()
+ def onSigGetInputImgSegForLostIDsWorker(self, image_channel_name):
+ displayed_input_label = 'Displayed image'
+ posData = self.data[self.pos_i]
+
+ if (
+ not image_channel_name
+ or image_channel_name == displayed_input_label
+ ):
+ img = self.getDisplayedImg1()
+ self.SegForLostIDsWorker.inputImgForSegForLostIDs = img
+ self.SegForLostIDsWaitCond.wakeAll()
+ return
+
+ self.getChData(requ_ch={image_channel_name})
+
+ _, filename = self.getPathFromChName(image_channel_name, posData)
+ fluo_data = posData.fluo_data_dict.get(filename)
+ if posData.SizeT > 1:
+ fluo_img_data = fluo_data[posData.frame_i]
+ else:
+ fluo_img_data = fluo_data
+
+ self.SegForLostIDsWorker.inputImgForSegForLostIDs = fluo_img_data
+ self.SegForLostIDsWaitCond.wakeAll()
+
def onSigStoreData(
self, waitcond, pos_i=None, enforce=True, debug=False,
@@ -8628,10 +8925,17 @@ def onSigStoreData(
autosave=autosave, store_cca_df_copy=store_cca_df_copy)
waitcond.wakeAll()
- def onSigUpdateRP(self, waitcond, draw=True, debug=False, update_IDs=True,
- wl_update=True, wl_track_og_curr=False):
- self.update_rp(draw=draw, debug=debug, update_IDs=update_IDs,
- wl_update=wl_update, wl_track_og_curr=wl_track_og_curr)
+ def onSigUpdateRP(self, waitcond,
+ draw=True, debug=False, # og stuff
+ assignments=None, deletionIDs=None, # very quick upates, rp labels are changed but rest is same
+ specific_IDs=None, use_curr_view=False, use_bbox=False, preloaded_bbox=None, # for local updates to PR
+ wl_update=True, wl_track_og_curr=False,wl_update_lab=False, # wl stuff
+ ):
+ self.update_rp(draw=True, debug=False, # og stuff
+ assignments=None, deletionIDs=None, # very quick upates, rp labels are changed but rest is same
+ specific_IDs=None, use_curr_view=False, use_bbox=False, preloaded_bbox=None, # for local updates to PR
+ wl_update=True, wl_track_og_curr=False,wl_update_lab=False, # wl stuff
+ )
waitcond.wakeAll()
def onSigGetData(self, waitcond, debug=False):
@@ -8640,7 +8944,7 @@ def onSigGetData(self, waitcond, debug=False):
def SegForLostIDsWorkerFinished(self):
self.updateAllImages()
- self.update_rp()
+ self.update_rp() # will update when updating segoforlostIDs
self.store_data(autosave=True)
self.setFrameNavigationDisabled(disable=False, why='Segmentation for lost IDs')
@@ -8649,8 +8953,15 @@ def SegForLostIDsWorkerFinished(self):
self.progressWin.close()
self.progressWin = None
- def showImageDebug(self, img):
- imshow(img)
+ def showImageDebug(self, display_info):
+ title = ''
+ img_titles = None
+ if isinstance(display_info, dict):
+ title = display_info.get('title', '')
+ img_titles = display_info.get('img_titles', None)
+ imgs = display_info.get('images', [])
+ imshow(*imgs, window_title=str(title), figure_title=str(title),
+ axis_titles=img_titles)
def gui_raiseBottomLayoutContextMenu(self, event):
try:
@@ -8990,14 +9301,13 @@ def searchIDworkerCallback(self, posData, searchedID):
for frame_i in range(len(posData.segm_data)):
if frame_i >= len(posData.allData_li):
break
- lab = posData.allData_li[frame_i]['labels']
- if lab is None:
- rp = skimage.measure.regionprops(posData.segm_data[frame_i])
- IDs = set([obj.label for obj in rp])
- else:
- IDs = posData.allData_li[frame_i]['IDs']
- if searchedID in IDs:
+ rp = posData.allData_li[frame_i]['regionprops']
+ if rp is None:
+ lab = posData.segm_data[frame_i]
+ rp = regionprops.acdcRegionprops(lab, precache_centroids=False)
+ posData.allData_li[frame_i]['regionprops'] = rp
+ if searchedID in rp.IDs:
frame_i_found = frame_i
break
@@ -9014,8 +9324,7 @@ def warnIDnotFound(self, searchedID):
def goToObjectID(self, ID):
posData = self.data[self.pos_i]
- objIdx = posData.IDs_idxs[ID]
- obj = posData.rp[objIdx]
+ obj = posData.rp.get_obj_from_ID(ID)
self.goToZsliceSearchedID(obj)
self.highlightSearchedID(ID)
@@ -9026,8 +9335,7 @@ def goToLostObjectID(self, lostID, color=(255, 165, 0, 255)):
posData = self.data[self.pos_i]
frame_i = posData.frame_i
prev_rp = posData.allData_li[frame_i-1]['regionprops']
- prev_IDs_idxs = posData.allData_li[frame_i-1]['IDs_idxs']
- obj = prev_rp[prev_IDs_idxs[lostID]]
+ obj = prev_rp.get_obj_from_ID(lostID)
self.goToZsliceSearchedID(obj)
imageItem = self.getLostObjImageItem(0)
@@ -9038,7 +9346,11 @@ def goToLostObjectID(self, lostID, color=(255, 165, 0, 255)):
self.lostObjContoursImage[:] = 0
contours = []
- obj_contours = self.getObjContours(obj, all_external=True)
+ obj_contours = self.getObjContours(
+ obj,
+ all_external=True,
+ include_internal=self.showAllContoursToggle.isChecked()
+ )
contours.extend(obj_contours)
self.addLostObjsToLostObjImage(obj, lostID)
@@ -9050,8 +9362,7 @@ def goToAcceptedLostObjectID(self, acceptedLostID):
posData = self.data[self.pos_i]
frame_i = posData.frame_i
prev_rp = posData.allData_li[frame_i-1]['regionprops']
- prev_IDs_idxs = posData.allData_li[frame_i-1]['IDs_idxs']
- obj = prev_rp[prev_IDs_idxs[acceptedLostID]]
+ obj = prev_rp.get_obj_from_ID(acceptedLostID)
self.goToZsliceSearchedID(obj)
self.updateLostTrackedContoursImage(tracked_lost_IDs=[acceptedLostID])
@@ -9633,7 +9944,8 @@ def applyEditID(
self, clickedID, currentIDs, oldIDnewIDMapper, clicked_x, clicked_y, shift=False, doPropagateUnvisited=False
):
posData = self.data[self.pos_i]
-
+ rp = self.rpCurr2D() if (shift and self.isSegm3D) else posData.rp
+ use_3D_obj_centroid = not (shift and self.isSegm3D)
# Ask to propagate change to all future visited frames
key = 'Edit ID'
askAction = self.askHowFutureFramesActions[key]
@@ -9654,43 +9966,51 @@ def applyEditID(
lab = posData.lab
# Store undo state before modifying stuff
+ # no risk of merging IDs if we are working with rp and dont updaet in the middle...
self.storeUndoRedoStates(UndoFutFrames)
- maxID = max(posData.IDs, default=0)
- for old_ID, new_ID in oldIDnewIDMapper:
+ # could this be chained??? If yes we have to "simplify" to least swops to since we keep RP stale
+ # oldIDnewIDMapper
+ assignments = {}
+ for old_ID, new_ID in oldIDnewIDMapper:
if new_ID in currentIDs and not self.editIDmergeIDs:
- tempID = maxID + 1
- lab[lab == old_ID] = maxID + 1
- lab[lab == new_ID] = old_ID
- lab[lab == tempID] = new_ID
- maxID += 1
-
- old_ID_idx = currentIDs.index(old_ID)
- new_ID_idx = currentIDs.index(new_ID)
-
- # Append information for replicating the edit in tracking
- # List of tuples (y, x, replacing ID)
- objo = posData.rp[old_ID_idx]
- yo, xo = self.getObjCentroid(objo.centroid)
- objn = posData.rp[new_ID_idx]
- yn, xn = self.getObjCentroid(objn.centroid)
- if not math.isnan(yo) and not math.isnan(yn):
- yn, xn = int(yn), int(xn)
- posData.editID_info.append((yn, xn, new_ID))
- yo, xo = int(clicked_y), int(clicked_x)
- posData.editID_info.append((yo, xo, old_ID))
+ objo = rp.get_obj_from_ID(old_ID)
+ objn = rp.get_obj_from_ID(new_ID)
+
+ # Relabel old_ID to new ID, save since rp is "stale"
+ slc_o = objo.slice
+ mask_o = objo.image
+ lab[slc_o][mask_o] = new_ID
+
+ # Relabel new_ID to old_ID
+ slc_n = objn.slice
+ mask_n = objn.image
+ lab[slc_n][mask_n] = old_ID
+
+
+ # ¯\_(ツ)_/¯
+ if use_3D_obj_centroid:
+ objn_centroid = rp.get_centroid(old_ID, exact=True) #
+ yn, xn = self.getObjCentroid(objn_centroid)
+ if not math.isnan(yn):
+ yn, xn = int(yn), int(xn)
+ posData.editID_info.append((yn, xn, new_ID))
+ yo, xo = int(clicked_y), int(clicked_x)
+ posData.editID_info.append((yo, xo, old_ID))
+ assignments[new_ID] = old_ID
+ assignments[old_ID] = new_ID
else:
- lab[lab == old_ID] = new_ID
- if new_ID > maxID:
- maxID = new_ID
- old_ID_idx = posData.IDs.index(old_ID)
-
- # Append information for replicating the edit in tracking
- # List of tuples (y, x, replacing ID)
- obj = posData.rp[old_ID_idx]
- y, x = self.getObjCentroid(obj.centroid)
- if not math.isnan(y) and not math.isnan(y):
- y, x = int(y), int(x)
- posData.editID_info.append((y, x, new_ID))
+ # Use regionprops for old_ID
+ obj = rp.get_obj_from_ID(old_ID)
+ slc = obj.slice
+ mask = obj.image
+ lab[slc][mask] = new_ID
+ if use_3D_obj_centroid:
+ centroid = rp.get_centroid(old_ID, exact=True)
+ y, x = self.getObjCentroid(centroid)
+ if not math.isnan(y) and not math.isnan(x):
+ y, x = int(y), int(x)
+ posData.editID_info.append((y, x, new_ID))
+ assignments[old_ID] = new_ID
self.updateAssignedObjsAcdcTrackerSecondStep(new_ID)
@@ -9698,7 +10018,10 @@ def applyEditID(
self.set_2Dlab(lab)
# Update rps
- self.update_rp()
+ # When shift is active we edited the 2D slice of a 3D lab; the cached
+ # slice/image data in the old RP objects is stale so we must do a full
+ # recompute rather than the fast assignments-only path.
+ self.update_rp(assignments=None if (shift and self.isSegm3D) else assignments)
# Since we manually changed an ID we don't want to repeat tracking
self.setAllTextAnnotations()
@@ -9763,14 +10086,23 @@ def getHoverID(self, xdata, ydata, byPassShiftCheck=False):
ymin, xmin, ymax, xmax, diskMask = self.getDiskMask(xdata, ydata)
posData = self.data[self.pos_i]
- lab_2D = self.get_2Dlab(posData.lab)
+ lab_2D = self.get_2Dlab(posData.lab, force_z=False)
ID = lab_2D[ydata, xdata]
self.isHoverZneighID = False
if self.isSegm3D:
+ zProjHow = self.zProjComboBox.currentText()
+ isZslice = zProjHow == 'single z-slice'
z = self.z_lab()
SizeZ = posData.lab.shape[0]
doNotLinkThroughZ = self.brushButton.isChecked() and shift
- if doNotLinkThroughZ:
+ if not isZslice:
+ # In projection mode, ID comes from the projected 2D label image.
+ if self.brushHoverCenterModeAction.isChecked() or ID>0:
+ hoverID = ID
+ else:
+ masked_lab = lab_2D[ymin:ymax, xmin:xmax][diskMask]
+ hoverID = np.bincount(masked_lab).argmax()
+ elif doNotLinkThroughZ:
if self.brushHoverCenterModeAction.isChecked() or ID>0:
hoverID = ID
else:
@@ -9805,11 +10137,12 @@ def getHoverID(self, xdata, ydata, byPassShiftCheck=False):
else:
hoverIDc = 0
- if hoverIDa > 0:
+ # When clicking directly on an object, prefer current-slice ID.
+ if hoverIDb > 0:
+ hoverID = hoverIDb
+ elif hoverIDa > 0:
hoverID = hoverIDa
self.isHoverZneighID = True
- elif hoverIDb > 0:
- hoverID = hoverIDb
elif hoverIDc > 0:
hoverID = hoverIDc
self.isHoverZneighID = True
@@ -9840,7 +10173,7 @@ def setHoverToolSymbolColor(
shift = modifiers == Qt.ShiftModifier
posData = self.data[self.pos_i]
- Y, X = self.get_2Dlab(posData.lab).shape
+ Y, X = self.get_2Dlab(posData.lab, force_z=False).shape
if not myutils.is_in_bounds(xdata, ydata, X, Y):
return
@@ -10158,7 +10491,9 @@ def smoothAutoContWithSpline(self, n=3):
xxA, yyA = xx[::n], yy[::n]
rr, cc = skimage.draw.polygon(yyA, xxA)
self.autoContObjMask[rr, cc] = 1
- rp = skimage.measure.regionprops(self.autoContObjMask)
+ rp = regionprops.acdcRegionprops(
+ self.autoContObjMask, precache_centroids=False
+ )
if not rp:
return
obj = rp[0]
@@ -10258,14 +10593,6 @@ def annotateIsHistoryKnown(self, ID):
# If the cell with unknown history has a relative ID assigned to it
# we set the cca of it to the status it had BEFORE the assignment
posData.cca_df.loc[relID] = relID_cca
-
- # Update cell cycle info LabelItems
- obj_idx = posData.IDs.index(ID)
- rp_ID = posData.rp[obj_idx]
-
- if relID in posData.IDs:
- relObj_idx = posData.IDs.index(relID)
- rp_relID = posData.rp[relObj_idx]
self.setAllTextAnnotations()
self.drawAllMothBudLines()
@@ -10486,12 +10813,7 @@ def undoBudMothAssignment(self, ID):
posData.cca_df.at[relID, 'generation_num'] = 2
posData.cca_df.at[relID, 'cell_cycle_stage'] = 'G1'
posData.cca_df.at[relID, 'relationship'] = 'mother'
-
- obj_idx = posData.IDs.index(ID)
- relObj_idx = posData.IDs.index(relID)
- rp_ID = posData.rp[obj_idx]
- rp_relID = posData.rp[relObj_idx]
-
+
self.store_cca_df()
# Update cell cycle info LabelItems
@@ -11553,14 +11875,11 @@ def delBorderObj(self, checked):
self.storeUndoRedoStates(False)
posData = self.data[self.pos_i]
- posData.lab = skimage.segmentation.clear_border(
- posData.lab, buffer_size=1
+ edge_ids = myutils.clear_border(posData.lab, return_edge_ids=True) # modifies inplace
+ self.update_rp(deletionIDs=edge_ids)
+ self.update_cca_df_deletedIDs(
+ posData, edge_ids, dropInPast=False, dropInFuture=False
)
- oldIDs = posData.IDs.copy()
- self.update_rp()
- removedIDs = [ID for ID in oldIDs if ID not in posData.IDs]
- if posData.cca_df is not None:
- posData.cca_df = posData.cca_df.drop(index=removedIDs)
self.store_data()
self.updateAllImages()
@@ -11574,7 +11893,7 @@ def delNewObj(self, checked):
if frame_i == 0:
return
- prev_IDs = posData.allData_li[frame_i-1]['IDs']
+ prev_IDs = posData.allData_li[frame_i-1]['regionprops'].IDs
curr_IDs = posData.IDs
new_IDs = list(set(curr_IDs) - set(prev_IDs))
@@ -11583,7 +11902,7 @@ def delNewObj(self, checked):
lab[del_mask] = 0
posData.lab = lab
- self.update_rp()
+ self.update_rp(deletionIDs=new_IDs)
if posData.cca_df is not None:
posData.cca_df = posData.cca_df.drop(index=new_IDs)
@@ -11602,16 +11921,21 @@ def brushAutoHideToggled(self, checked):
def brushReleased(self):
posData = self.data[self.pos_i]
- self.fillHolesID(posData.brushID, sender='brush')
+ do_auto_fill = self.brushAutoFillCheckbox.isChecked()
+ self.fillHolesID(posData.brushID, sender='brush', enabled=do_auto_fill)
# Update data (rp, etc)
- self.update_rp(update_IDs=self.isNewID,)
+
+ power_brush = self.isPowerBrush()
+ # we have to delay for a second
+ self.update_rp(
+ use_curr_view=True,
+ specific_IDs=posData.brushID if not power_brush else None
+ )
# Repeat tracking
if self.autoIDcheckbox.isChecked():
self.trackManuallyAddedObject(posData.brushID, self.isNewID)
- else:
- self.update_rp(update_IDs=posData.brushID not in posData.IDs_idxs)
# Update images
if self.isNewID:
@@ -11783,7 +12107,7 @@ def delROImoving(self, roi):
def delROImovingFinished(self, roi: pg.ROI):
roi.setPen(color='r')
- self.update_rp()
+ self.update_rp() # get bbox of delROI old and new, run update_rp on both seperately
self.updateAllImages()
QTimer.singleShot(
300, partial(self.updateDelROIinFutureFrames, roi)
@@ -11821,7 +12145,7 @@ def restoreAnnotDelROI(self, roi, enforce=True, draw=True):
delROIs_info['delIDsROI'][idx] = delIDs - restoredIDs
self.set_2Dlab(lab2D)
- self.update_rp()
+ self.update_rp() # get bbox of delROI old and new, run update_rp on both seperately
def restoreDelROIimg1(self, delMaskID, delID, ax=0):
if ax == 0:
@@ -11833,7 +12157,9 @@ def restoreDelROIimg1(self, delMaskID, delID, ax=0):
return
if how.find('contours') != -1:
- rp_delmask = skimage.measure.regionprops(delMaskID.astype(np.uint8))
+ rp_delmask = regionprops.acdcRegionprops(
+ delMaskID.astype(np.uint8), precache_centroids=False
+ )
if len(rp_delmask) > 0:
obj = rp_delmask[0]
self.addObjContourToContoursImage(obj=obj, ax=ax)
@@ -11903,7 +12229,9 @@ def getDelROIlab(self, input_lab_2D=None):
idx = delROIs_info['rois'].index(roi)
delObjROImask = delROIs_info['delMasks'][idx]
delIDsROI = delROIs_info['delIDsROI'][idx]
- delROIlabRp = skimage.measure.regionprops(out_lab)
+ delROIlabRp = regionprops.acdcRegionprops(
+ out_lab, precache_centroids=False
+ )
for delObj in delROIlabRp:
isDelObj = np.any(ROImask[delObj.slice][delObj.image])
if not isDelObj:
@@ -12918,7 +13246,6 @@ def changeMode(self, text):
self.addExistingDelROIs()
self.isFirstTimeOnNextFrame()
self.setEnabledCcaToolbar(enabled=False)
- self.clearComputedContours()
self.realTimeTrackingToggle.setDisabled(False)
self.realTimeTrackingToggle.label.setDisabled(False)
if posData.cca_df is not None:
@@ -12935,9 +13262,7 @@ def changeMode(self, text):
self.modeToolBar.setVisible(True)
self.realTimeTrackingToggle.setDisabled(True)
self.realTimeTrackingToggle.label.setDisabled(True)
- self.computeAllContours()
# RAWR!!!!!
- # self.computeAllObjToObjCostPairs()
if proceed:
self.setEnabledEditToolbarButton(enabled=False)
if self.isSnapshot:
@@ -12959,7 +13284,6 @@ def changeMode(self, text):
self.navigateScrollBar.setMaximum(posData.SizeT)
self.navSpinBox.setMaximum(posData.SizeT)
self.clearGhost()
- self.computeAllContours()
elif mode == 'Custom annotations':
self.setAutoSaveAnnotationsEnabled(True)
self.setSwitchViewedPlaneDisabled(True)
@@ -12972,14 +13296,12 @@ def changeMode(self, text):
self.annotateToolbar.setVisible(True)
self.clearGhost()
self.doCustomAnnotation(0)
- self.computeAllContours()
elif mode == 'Snapshot':
self.setAutoSaveAnnotationsEnabled(True)
self.setSwitchViewedPlaneDisabled(False)
self.reconnectUndoRedo()
self.setEnabledSnapshotMode()
self.doCustomAnnotation(0)
- self.clearComputedContours()
elif mode == 'Normal division: Lineage tree': # Mode activation for lineage tree
# self.startLinTreeIntegrityCheckerWorker() # need to replace (postponed)
proceed = self.initLinTree()
@@ -13280,8 +13602,7 @@ def manualAnnotPast_cb(self, checked):
)
self.editIDspinbox.setValue(hoverID)
try:
- obj_idx = posData.IDs_idxs[hoverID]
- obj = posData.rp[obj_idx]
+ obj = posData.rp.get_obj_from_ID(hoverID)
radius = 0.9 * obj.minor_axis_length / 2 # math.sqrt(obj.area/math.pi)*0.9
self.brushSizeSpinbox.setValue(round(radius))
except Exception as err:
@@ -13384,16 +13705,15 @@ def _copyAllLostObjects_navigateToFrame(self, frame_i):
self.get_data()
self.tracking(wl_update=False)
self.currentLab2D = self.get_2Dlab(posData.lab)
- self.update_rp()
+ self.update_rp() # cannot be more granular as lost obj could be anywhere
self.updateLostNewCurrentIDs()
self.store_data(mainThread=False, autosave=False)
self.lostObjContoursImage[:] = 0
self.lostObjImage[:] = 0
prev_rp = posData.allData_li[frame_i-1]['regionprops']
- prev_IDs_idxs = posData.allData_li[frame_i-1]['IDs_idxs'] # need to change this when merging with opt.
for lostID in posData.lost_IDs:
- obj = prev_rp[prev_IDs_idxs[lostID]]
+ obj = prev_rp.get_obj_from_ID(lostID)
self.addLostObjsToLostObjImage(obj, lostID, force=True)
def _copyAllLostObjects_returnToFrame(self, frame_i):
@@ -13495,7 +13815,7 @@ def copyAllLostObjectsWorkerFinished(self, output):
self.blinker.start()
self.copyAllLostObjectsWorkerLoop.exit()
- self.update_rp()
+ self.update_rp() # global op and obj added, no opt imo unless difference pic
self.updateAllImages()
self.store_data()
@@ -13638,7 +13958,9 @@ def clearObjsFreehandRegion(self):
regionLab = transformation.clear_objects_not_in_mask(
regionLab, mask
)
- regionRp = skimage.measure.regionprops(regionLab)
+ regionRp = regionprops.acdcRegionprops(
+ regionLab, precache_centroids=False
+ )
for obj in regionRp:
if np.all(mask[obj.slice][obj.image]):
continue
@@ -13652,7 +13974,9 @@ def clearObjsFreehandRegion(self):
else:
regionLab[..., ~mask] = 0
- regionRp = skimage.measure.regionprops(regionLab)
+ regionRp = regionprops.acdcRegionprops(
+ regionLab, precache_centroids=False
+ )
clearIDs = [obj.label for obj in regionRp]
if not clearIDs:
@@ -13711,7 +14035,7 @@ def labelRoiDone(self, roiSegmData, isTimeLapse):
frame_i = start_frame_i + i
lab = posData.allData_li[frame_i]['labels']
store = True
- if lab is None:
+ if lab is None: # no rp update here?
if frame_i >= len(posData.segm_data):
lab = np.zeros_like(posData.segm_data[0])
posData.segm_data = np.append(
@@ -13727,6 +14051,7 @@ def labelRoiDone(self, roiSegmData, isTimeLapse):
if store:
posData.frame_i = frame_i
posData.allData_li[frame_i]['labels'] = lab.copy()
+ # no rp update here?
self.get_data()
self.store_data(autosave=False)
@@ -13739,7 +14064,7 @@ def labelRoiDone(self, roiSegmData, isTimeLapse):
roiLab, self.labelRoiSlice, posData.lab, posData.brushID
)
- self.update_rp()
+ self.update_rp() # get roi and set as bbox
# Repeat tracking
if self.autoIDcheckbox.isChecked():
@@ -13770,8 +14095,7 @@ def labelRoiDone(self, roiSegmData, isTimeLapse):
def restoreHoverObjBrush(self):
posData = self.data[self.pos_i]
if self.ax1BrushHoverID in posData.IDs:
- obj_idx = posData.IDs_idxs[self.ax1BrushHoverID]
- obj = posData.rp[obj_idx]
+ obj = posData.rp.get_obj_from_ID(self.ax1BrushHoverID)
if not self.isObjVisible(obj.bbox):
return
@@ -13863,15 +14187,20 @@ def setAllIDs(self, onlyVisited=False):
for frame_i in range(len(posData.segm_data)):
if frame_i >= len(posData.allData_li):
break
+
lab = posData.allData_li[frame_i]['labels']
if lab is None and onlyVisited:
break
- if lab is None:
- rp = skimage.measure.regionprops(posData.segm_data[frame_i])
- else:
- rp = posData.allData_li[frame_i]['regionprops']
- posData.allIDs.update([obj.label for obj in rp])
+ rp = posData.allData_li[frame_i]['regionprops']
+ if rp is None:
+ lab = posData.segm_data[frame_i]
+ rp = regionprops.acdcRegionprops(
+ lab, precache_centroids=False
+ )
+ posData.allData_li[frame_i]['regionprops'] = rp
+
+ posData.allIDs.update(rp.IDs)
def countObjectsTimelapse(self):
if self.countObjsWindow is None:
@@ -13920,11 +14249,13 @@ def countObjectsSnapshots(self):
numObjectsCurrentZslice = None
if 'In current z-slice' in activeCategories:
numObjectsCurrentZslice = len(
- skimage.measure.regionprops(self.currentLab2D)
+ regionprops.acdcRegionprops(
+ self.currentLab2D, precache_centroids=False
+ )
)
for pos_i, _posData in enumerate(self.data):
- IDs = _posData.allData_li[0]['IDs']
+ IDs = _posData.allData_li[0]['regionprops'].IDs
if os.path.exists(_posData.acdc_output_csv_path):
numObjectsVisitedPosPrevious += len(IDs)
if IDs:
@@ -13932,7 +14263,9 @@ def countObjectsSnapshots(self):
numObjectsAllPos += len(IDs)
else:
lab = _posData.segm_data[0]
- rp = skimage.measure.regionprops(lab)
+ rp = regionprops.acdcRegionprops(
+ lab, precache_centroids=False
+ )
numObjs = len(rp)
numObjectsAllPos += numObjs
@@ -14814,10 +15147,10 @@ def keyPressEvent(self, ev):
if ev.key() == Qt.Key_Q and self.debug:
try:
from . import _q_debug
- _q_debug.q_debug(self)
+ _q_debug.q_debug(self, ev)
except Exception as err:
printl(traceback.format_exc())
- printl('[ERROR]: Error with "_qdebug" module. See Traceback above.')
+ printl('[ERROR]: Error with "_q_debug" module. See Traceback above.')
pass
if not self.isDataLoaded:
@@ -15159,7 +15492,7 @@ def propagateMergeObjsPast(self, IDs_to_merge):
posData.frame_i = past_frame_i
self.get_data()
- IDs = posData.allData_li[past_frame_i]['IDs']
+ IDs = posData.allData_li[past_frame_i]['regionprops'].IDs
stop_loop = False
for ID in IDs_to_merge:
if ID not in IDs:
@@ -15168,10 +15501,13 @@ def propagateMergeObjsPast(self, IDs_to_merge):
if ID == 0:
continue
- posData.lab[posData.lab==ID] = self.firstID
- self.update_rp()
-
- self.store_data(autosave=False)
+ obj = posData.rp.get_obj_from_ID(ID)
+ posData.lab[obj.slice][obj.image] = self.firstID
+
+ preloaded_bbox = self.update_rp_get_bbox(specific_IDs=IDs_to_merge,use_bbox=True) # use old RP to get the correct bbox
+ specific_IDs = [*IDs_to_merge, self.firstID]
+ self.update_rp(preloaded_bbox=preloaded_bbox, specific_IDs=specific_IDs)
+ self.store_data(autosave=False)
if stop_loop:
break
@@ -15210,11 +15546,10 @@ def propagateChange(
# Stop at last visited frame since includeUnvisited = False
break
else:
- lab = posData.segm_data[i]
+ IDs = posData.allData_li[i]['regionprops'].IDs
else:
- lab = posData.allData_li[i]['labels']
-
- if modID in lab:
+ IDs = posData.allData_li[i]['regionprops'].IDs
+ if modID in IDs:
areFutureIDs_affected.append(True)
if not last_tracked_i_found:
@@ -15694,13 +16029,19 @@ def warnTrackerInputNotValid(self, trackerName, warningText):
def repeatTracking(self):
posData = self.data[self.pos_i]
- prev_lab = self.get_2Dlab(posData.lab).copy()
- self.tracking(enforce=True, DoManualEdit=False)
+ tracked_lab, assignments = self.tracking(
+ enforce=True,
+ DoManualEdit=False,
+ return_assignments=True,
+ return_lab=True
+ )
+ posData.lab = tracked_lab
if posData.editID_info:
+ lab2D = self.get_2Dlab(posData.lab)
editedIDsInfo = {
- posData.lab[y,x]:newID
+ lab2D[y,x]:newID
for y, x, newID in posData.editID_info
- if posData.lab[y,x] != newID
+ if lab2D[y,x] != newID
}
editedIDsInfoItems = [
f'ID {oldID} --> {newID}'
@@ -15726,18 +16067,26 @@ def repeatTracking(self):
detailsText=editIDul
)
if msg.cancel:
+ self.update_rp(assignments=assignments) # rp now stale as we return img
return
if msg.clickedButton == keepManualEditButton:
- allIDs = [obj.label for obj in posData.rp]
+ allIDs = posData.rp.IDs
lab2D = self.get_2Dlab(posData.lab)
- self.manuallyEditTracking(lab2D, allIDs)
- self.update_rp()
+ tracked_lab, assignments = self.manuallyEditTracking(
+ lab2D, assignments) # here not use tracked lab?
+ self.update_rp(assignments=assignments) # rp now stale as we return img
self.setAllTextAnnotations()
self.highlightLostNew()
# self.checkIDsMultiContour()
else:
+ self.update_rp(assignments=assignments) # rp now stale as we return img
posData.editID_info = []
- if np.any(posData.lab != prev_lab):
+ else:
+ self.update_rp(assignments=assignments)
+
+ # filter self assignments
+ assignments = {k: v for k, v in assignments.items() if k != v}
+ if assignments:
if self.isSnapshot:
self.fixCcaDfAfterEdit('Repeat tracking')
self.updateAllImages()
@@ -15809,8 +16158,7 @@ def initManualBackgroundObject(self, ID=None):
self.manualBackgroundObjItem.clear()
return
- ID_idx = posData.IDs_idxs[ID]
- self.manualBackgroundObj = posData.rp[ID_idx]
+ self.manualBackgroundObj = posData.rp.get_obj_from_ID(ID)
self.manualBackgroundToolbar.clearInfoText()
self.manualBackgroundObj.contour = self.getObjContours(
@@ -16908,15 +17256,14 @@ def doCustomAnnotation(self, ID):
xx, yy = [], []
for annotID in annotIDs_frame_i:
- if annotID not in posData.IDs_idxs:
+ if annotID not in posData.rp.IDs:
continue
-
- obj_idx = posData.IDs_idxs[annotID]
- obj = posData.rp[obj_idx]
+ obj = posData.rp.get_obj_from_ID(annotID)
acdc_df.at[annotID, state['name']] = 1
if not self.isObjVisible(obj.bbox):
continue
- y, x = self.getObjCentroid(obj.centroid)
+ y, x = self.getObjCentroid(
+ posData.rp.get_centroid(annotID, exact=True))
xx.append(x)
yy.append(y)
@@ -17564,11 +17911,13 @@ def segmWorkerFinished(self, lab, exec_time):
self.update_rp(wl_update=False)
self.tracking(enforce=True, against_next=posData.frame_i==0)
- if self.isSnapshot:
- self.fixCcaDfAfterEdit('Repeat segmentation')
- self.updateAllImages()
- else:
- self.warnEditingWithCca_df('Repeat segmentation')
+ proceed = self.checkHandleTooManyNewItems()
+ if proceed:
+ if self.isSnapshot:
+ self.fixCcaDfAfterEdit('Repeat segmentation')
+ self.updateAllImages()
+ else:
+ self.warnEditingWithCca_df('Repeat segmentation')
txt = f'Done. Segmentation computed in {exec_time:.3f} s'
self.logger.info('-----------------')
@@ -17788,7 +18137,7 @@ def zoomToCells(self, enforce=False):
posData = self.data[self.pos_i]
lab_mask = (self.currentLab2D>0).astype(np.uint8)
- rp = skimage.measure.regionprops(lab_mask)
+ rp = regionprops.acdcRegionprops(lab_mask, precache_centroids=False)
if not rp:
Y, X = lab_mask.shape
xRange = -0.5, X+0.5
@@ -18203,9 +18552,8 @@ def warnLostObjects(self, do_warn=True):
posData.accepted_lost_IDs[frame_i].extend(posData.lost_IDs)
# This section is adding the lost cells to tracked_lost_centroids... TBH I dont know why this wasnt done in the first place
prev_rp = posData.allData_li[posData.frame_i-1]['regionprops']
- prev_IDs_idxs = posData.allData_li[posData.frame_i-1]['IDs_idxs']
accepted_lost_centroids = {
- tuple(int(val) for val in prev_rp[prev_IDs_idxs[ID]].centroid)
+ tuple(int(val) for val in prev_rp.get_centroid(ID, exact=True))
for ID in posData.lost_IDs
}
try:
@@ -18285,8 +18633,7 @@ def checkIfFutureFrameManualAnnotPastFrames(self):
self.statusBarLabel.setText(f'{warn_txt}
')
return False
-
- # @exec_time
+
def next_frame(self, warn=True):
proceed = self.checkIfFutureFrameManualAnnotPastFrames()
if not proceed:
@@ -18350,7 +18697,6 @@ def next_frame(self, warn=True):
)
return
- self.store_zslices_rp()
self.resetExpandLabel()
self.updateAllImages()
self.updateHighlightedAxis()
@@ -18815,6 +19161,7 @@ def loadSelectedData(self, user_ch_file_paths, user_ch_name):
create_new_segm=self.isNewFile,
new_endname=self.newSegmEndName,
end_filename_segm=selectedSegmEndName,
+ load_segm_info_ini=True
)
self.selectedSegmEndName = selectedSegmEndName
self.labelBoolSegm = posData.labelBoolSegm
@@ -20240,7 +20587,7 @@ def curvToolSplineToObj(self, xxA=None, yyA=None, isRightClick=False):
xxS, yyS = self.curvPlotItem.getData()
if xxS is None:
self.setUncheckedAllButtons()
- return
+ return None, None
self.smoothAutoContWithSpline()
xxS, yyS = self.getClosedSplineCoords()
@@ -20264,6 +20611,7 @@ def curvToolSplineToObj(self, xxA=None, yyA=None, isRightClick=False):
lab2D[newIDMask] = curvToolID
self.set_2Dlab(lab2D)
self.currentLab2D = lab2D
+ return newIDMask, curvToolID
def addFluoChNameContextMenuAction(self, ch_name):
posData = self.data[self.pos_i]
@@ -20525,6 +20873,10 @@ def framesScrollBarMoved(self, frame_n):
posData.lab = posData.segm_data[posData.frame_i]
else:
posData.lab = np.zeros_like(posData.segm_data[0])
+ rp = regionprops.acdcRegionprops(posData.lab, precache_centroids=False)
+ posData.rp = rp
+ posData.IDs = []
+ posData.allData_li[posData.frame_i]['regionprops'] = rp
else:
posData.lab = posData.allData_li[posData.frame_i]['labels']
@@ -20547,7 +20899,10 @@ def framesScrollBarMoved(self, frame_n):
def framesScrollBarReleased(self, do_store_data=False):
posData = self.data[self.pos_i]
- if posData.frame_i == self.navigateScrollBar.sliderPosition()-1:
+ if (
+ posData.frame_i == self.navigateScrollBar.sliderPosition()-1
+ and self.navigateScrollBarStartedMoving
+ ):
# Slider released without changing value --> do nothing
return
@@ -20575,37 +20930,67 @@ def getStoredSegmData(self):
segm_data.append(lab)
return np.array(segm_data)
- def trackNewIDtoNewIDsFutureFrame(self, newID, newIDmask):
+ def trackNewIDtoNewIDsFutureFrame(self, newID, obj, assignments):
+ # here RP is stale
posData = self.data[self.pos_i]
try:
nextLab = posData.allData_li[posData.frame_i+1]['labels']
except IndexError:
# This is last frame --> there are no future frames
- return
+ return None, assignments
if nextLab is None:
- return
+ return None, assignments
+
+ if obj is None:
+ return None, assignments
+
- newID_lab = np.zeros_like(posData.lab)
- newID_lab[newIDmask] = newID
- newLab_rp = [posData.rp[posData.IDs_idxs[newID]]]
- newLab_IDs = [newID]
nextRp = posData.allData_li[posData.frame_i+1]['regionprops']
+ nextLab = posData.allData_li[posData.frame_i+1]['labels']
+ reverse_assignments = {v:k for k, v in assignments.items()}
- tracked_lab = self.trackFrame(
- nextLab, nextRp, newID_lab, newLab_rp, newLab_IDs,
- assign_unique_new_IDs=False
+ rp = posData.rp
+ lab = posData.lab
+
+ # make rp remporarliy not stale anymore
+ rp.update_regionprops_via_assignments(assignments, lab)
+ tracked_lab, assignments_new = self.trackFrame(
+ nextLab, nextRp, lab, rp, rp.IDs,
+ assign_unique_new_IDs=False, return_assignments=True,
+ specific_IDs=[newID],
)
- trackedID = tracked_lab[newID_lab>0][0]
+ # restore rp
+ posData.rp.update_regionprops_via_assignments(reverse_assignments, lab)
+
+ # clear self assignments
+ assignments_new = {
+ k:v for k, v in assignments_new.items() if k != v
+ }
+ if not assignments_new:
+ return None, assignments
+
+ trackedIDs = list(assignments_new.values())
+
+ trackedID = trackedIDs[0]
if trackedID == newID:
# Object does not exist in future frame --> do not track
- return
+ return None, assignments
- if posData.IDs_idxs.get(trackedID) is not None:
+ if posData.rp.get_obj_from_ID(trackedID, warn=False) is not None:
# Tracked ID already exists --> do not track to avoid merging
- return
+ return None, assignments
- return trackedID
+
+
+ # update assignments
+ assignments = {
+ old_ID: tracked_ID for old_ID, tracked_ID in assignments.items()
+ if old_ID != newID
+ }
+ assignments[newID] = trackedID
+
+ return trackedID, assignments
def store_manual_annot_data(
self, posData=None, data_frame_i=None
@@ -20648,21 +21033,19 @@ def store_data(
# self.lin_tree_ask_changes()
allData_li = posData.allData_li[posData.frame_i]
- allData_li['regionprops'] = posData.rp.copy()
+
+
+ allData_li['regionprops'] = posData.rp
allData_li['labels'] = posData.lab.copy()
- allData_li['IDs'] = posData.IDs.copy()
+ allData_li['regionprops'].IDs = posData.IDs
allData_li['manualBackgroundLab'] = (
posData.manualBackgroundLab
)
- allData_li['IDs_idxs'] = (
- posData.IDs_idxs.copy()
- )
if self.manualAnnotPastButton.isChecked():
self.store_manual_annot_data(
- posData=posData, data_frame_i=allData_li
+ posData=posData, data_frame_i=allData_li
)
- self.store_zslices_rp()
# Store dynamic metadata
is_cell_dead_li = [False]*len(posData.rp)
@@ -20678,13 +21061,17 @@ def store_data(
is_cell_dead_li[i] = obj.dead
is_cell_excluded_li[i] = obj.excluded
IDs[i] = obj.label
- try:
- xx_centroid[i] = int(self.getObjCentroid(obj.centroid)[1])
- yy_centroid[i] = int(self.getObjCentroid(obj.centroid)[0])
- except Exception as err:
- printl(obj, obj.centroid, obj.label, posData.frame_i)
+ centroid = posData.rp.get_centroid(obj.label, exact=True)
+ if centroid is None:
+ continue
+
if self.isSegm3D:
- zz_centroid[i] = int(obj.centroid[0])
+ zz_centroid[i] = int(centroid[0])
+ xx_centroid[i] = int(centroid[2])
+ yy_centroid[i] = int(centroid[1])
+ else:
+ xx_centroid[i] = int(centroid[1])
+ yy_centroid[i] = int(centroid[0])
if obj.label in editedNewIDs:
areManuallyEdited[i] = 1
@@ -21366,7 +21753,18 @@ def get_2Dlab(self, lab, force_z=True):
if isZslice:
return lab[self.z_lab()]
else:
- return lab.max(axis=0)
+ if self.switchPlaneCombobox.isEnabled():
+ slicing = self.switchPlaneCombobox.depthAxes()
+ else:
+ slicing = 'z'
+
+ posData = self.data[self.pos_i]
+ rp = getattr(posData, 'rp', None)
+ if rp is not None and rp.is3D:
+ return rp.get_projection_lab_sorted(slicing=slicing)
+
+ rp = regionprops.acdcRegionprops(lab, precache_centroids=False)
+ return rp.get_projection_lab_sorted(slicing=slicing)
else:
return lab
@@ -21455,24 +21853,13 @@ def applyBrushMask(self, mask, ID, toLocalSlice=None):
posData.lab[mask] = ID
def assignNewIDfromClickedID(
- self, clickedID: int, event: QGraphicsSceneMouseEvent
+ self, clickedID: int, event: QGraphicsSceneMouseEvent, shift: bool = False
):
posData = self.data[self.pos_i]
x, y = event.pos().x(), event.pos().y()
newID = self.setBrushID(return_val=True)
mapper = [(clickedID, newID)]
- self.applyEditID(clickedID, posData.IDs.copy(), mapper, x, y)
-
- def get_2Drp(self, lab=None):
- if self.isSegm3D:
- if lab is None:
- # self.currentLab2D is defined at self.setImageImg2()
- lab = self.currentLab2D
- lab = self.get_2Dlab(lab)
- rp = skimage.measure.regionprops(lab)
- return rp
- else:
- return self.data[self.pos_i].rp
+ self.applyEditID(clickedID, posData.IDs.copy(), mapper, x, y, shift=shift)
def set_2Dlab(self, lab2D, lab3D=None):
posData = self.data[self.pos_i]
@@ -21555,6 +21942,11 @@ def get_labels(
else:
shape = (posData.SizeY, posData.SizeX)
labels = np.zeros(shape, dtype=np.uint32)
+ rp = regionprops.acdcRegionprops(labels, precache_centroids=False)
+ if frame_i == posData.frame_i:
+ posData.rp = rp
+ posData.IDs = []
+ posData.allData_li[frame_i]['regionprops'] = rp
return_copy = False
if return_copy:
@@ -21568,10 +21960,12 @@ def get_labels(
def addYXcentroidToDf(self, df):
posData = self.data[self.pos_i]
for obj in posData.rp:
- y_centroid = int(self.getObjCentroid(obj.centroid)[0])
- x_centroid = int(self.getObjCentroid(obj.centroid)[1])
- df.at[obj.label, 'y_centroid'] = y_centroid
- df.at[obj.label, 'x_centroid'] = x_centroid
+ ID = obj.label
+ centroid = posData.rp.get_centroid(obj, exact=True)
+ y_centroid = int(self.getObjCentroid(centroid)[0])
+ x_centroid = int(self.getObjCentroid(centroid)[1])
+ df.at[ID, 'y_centroid'] = y_centroid
+ df.at[ID, 'x_centroid'] = x_centroid
return df
def _get_editID_info(self, df):
@@ -21651,7 +22045,10 @@ def _get_data_unvisited(self, posData, debug=False, lin_tree_init=True,):
posData.lab = self.apply_manual_edits_to_lab_if_needed(
labels
)
- posData.rp = skimage.measure.regionprops(posData.lab)
+ posData.rp = posData.allData_li[posData.frame_i]['regionprops']
+ if posData.rp is None:
+ posData.rp = regionprops.acdcRegionprops(labels, precache_centroids=False)
+ # get stored IDs
self.setManualBackgroundLab()
if posData.acdc_df is not None:
@@ -21694,7 +22091,8 @@ def _get_data_visited(self, posData, debug=False, lin_tree_init=True,):
# Requested frame was already visited. Load from RAM.
never_visited = False
posData.lab = self.get_labels(from_store=True)
- posData.rp = skimage.measure.regionprops(posData.lab)
+ # posData.rp = regionprops.acdcRegionprops(posData.lab, precache_centroids=False)
+ posData.rp = posData.allData_li[posData.frame_i]['regionprops']
df = posData.allData_li[posData.frame_i]['acdc_df']
if df is None:
posData.binnedIDs = set()
@@ -21734,6 +22132,7 @@ def get_data(self, debug=False, lin_tree_init=True):
else:
self.undoAction.setDisabled(True)
self.UndoCount = 0
+
# If stored labels is None then it is the first time we visit this frame
if posData.allData_li[posData.frame_i]['labels'] is None:
proceed_cca, never_visited = self._get_data_unvisited(
@@ -21746,12 +22145,12 @@ def get_data(self, debug=False, lin_tree_init=True):
posData, lin_tree_init=lin_tree_init, debug=debug
)
+ if posData.rp is None: #
+ rp = regionprops.acdcRegionprops(posData.lab, precache_centroids=False)
+ posData.rp = rp
+ posData.allData_li[posData.frame_i]['regionprops'] = rp
self.update_rp_metadata(draw=False)
- posData.IDs = [obj.label for obj in posData.rp]
- posData.IDs_idxs = {
- ID:i for ID, i in zip(posData.IDs, range(len(posData.IDs)))
- }
- self.get_zslices_rp()
+ posData.IDs = posData.rp.IDs
self.pointsLayerDfsToData(posData)
return proceed_cca, never_visited
@@ -21777,7 +22176,7 @@ def addIDBaseCca_df(self, posData, ID):
def getBaseCca_df(self, with_tree_cols=False):
posData = self.data[self.pos_i]
- IDs = [obj.label for obj in posData.rp]
+ IDs = posData.rp.IDs
cca_df = core.getBaseCca_df(IDs, with_tree_cols=with_tree_cols)
return cca_df
@@ -22339,6 +22738,32 @@ def get_cca_df(self, frame_i=None, return_df=False, debug=False):
return cca_df
else:
posData.cca_df = cca_df
+
+ def _changeIDhelper(self, lab, oldID, newID, rp, assignments):
+ did_find_newID = False
+ if newID in rp.IDs: # should here also self.editIDmergeIDs?
+ # Relabel old_ID to tempID, safe as RP is safe so no merging
+ objo = rp.get_obj_from_ID(oldID, warn=False)
+ if objo is not None:
+ slc_o = objo.slice
+ mask_o = objo.image
+ lab[slc_o][mask_o] = newID
+ assignments[oldID] = newID
+ # Relabel new_ID to old_ID
+ objn = rp.get_obj_from_ID(newID) # here warn, we check in the if if it should be there
+ objn_slice = objn.slice
+ objn_mask = objn.image
+ lab[objn_slice][objn_mask] = oldID
+ assignments[newID] = oldID
+ did_find_newID = True
+ else:
+ obj = rp.get_obj_from_ID(oldID, warn=False)
+ if obj is not None:
+ slc = obj.slice
+ mask = obj.image
+ lab[slc][mask] = newID
+ assignments[oldID] = newID
+ return did_find_newID
def changeIDfutureFrames(
self, endFrame_i, oldIDnewIDMapper, includeUnvisited,
@@ -22355,6 +22780,7 @@ def changeIDfutureFrames(
segmSizeT = len(posData.segm_data)
for i in range(posData.frame_i+1, segmSizeT):
+ assignments = {}
lab = posData.allData_li[i]['labels']
if lab is None and not includeUnvisited:
self.enqAutosave()
@@ -22366,49 +22792,47 @@ def changeIDfutureFrames(
self.get_data(lin_tree_init=False)
if shift and self.isSegm3D:
lab = self.get_2Dlab(posData.lab)
+ rp = self.rpCurr2D()
else:
lab = posData.lab
-
+ rp = posData.rp
+
if self.onlyTracking:
self.tracking(enforce=True)
elif not posData.IDs:
continue
else:
- maxID = max(posData.IDs, default=0) + 1
for old_ID, new_ID in oldIDnewIDMapper:
- if new_ID in lab:
- tempID = maxID + 1 # lab.max() + 1
- lab[lab == old_ID] = tempID
- lab[lab == new_ID] = old_ID
- lab[lab == tempID] = new_ID
- maxID += 1
- else:
- lab[lab == old_ID] = new_ID
-
+ self._changeIDhelper(
+ lab, old_ID, new_ID, rp, assignments)
+
if shift and self.isSegm3D:
self.set_2Dlab(lab)
-
- self.update_rp(draw=False)
+
+ self.update_rp(
+ draw=False,
+ assignments=assignments if not (shift and self.isSegm3D) else None)
self.store_data(autosave=i==endFrame_i)
elif includeUnvisited:
# Unvisited frame (includeUnvisited = True)
lab = posData.segm_data[i]
if shift and self.isSegm3D:
lab = self.get_2Dlab(lab)
+ rp = self.rpCurr2D(frame_i=i)
else:
lab = lab
-
+ rp = posData.allData_li[i]['regionprops']
+
+ assignments = {}
for old_ID, new_ID in oldIDnewIDMapper:
- if new_ID in lab:
- tempID = lab.max() + 1
- lab[lab == old_ID] = tempID
- lab[lab == new_ID] = old_ID
- lab[lab == tempID] = new_ID
- else:
- lab[lab == old_ID] = new_ID
-
+ self._changeIDhelper(
+ lab, old_ID, new_ID, rp, assignments)
+
if shift and self.isSegm3D:
posData.segm_data[i][self.z_lab()] = lab
+ rp.update_regionprops(lab)
+ else:
+ rp.update_regionprops_via_assignments(assignments, lab)
# Back to current frame
posData.frame_i = self.current_frame_i
@@ -22682,14 +23106,11 @@ def drawObjMothBudLines(self, obj, posData, ax=0):
scatterItem = self.getMothBudLineScatterItem(ax, isNew)
relative_ID = cca_df_ID['relative_ID']
- try:
- relative_rp_idx = posData.IDs_idxs[relative_ID]
- except KeyError:
- return
-
- relative_ID_obj = posData.rp[relative_rp_idx]
- y1, x1 = self.getObjCentroid(obj.centroid)
- y2, x2 = self.getObjCentroid(relative_ID_obj.centroid)
+ relative_ID_obj = posData.rp.get_obj_from_ID(relative_ID)
+ obj_centroid = posData.rp.get_centroid(ID)
+ rel_obj_centroid = posData.rp.get_centroid(relative_ID)
+ y1, x1 = self.getObjCentroid(obj_centroid)
+ y2, x2 = self.getObjCentroid(rel_obj_centroid)
xx, yy = core.get_line(y1, x1, y2, x2, dashed=True)
scatterItem.addPoints(xx, yy)
@@ -22731,14 +23152,14 @@ def drawAllLineageTreeLines(self):
continue
for ID in new_cells:
- curr_obj = myutils.get_obj_by_label(rp, ID)
+ curr_obj = rp.get_obj_from_ID(ID)
lin_tree_df_ID = lin_tree_df.loc[ID]
# lin_tree_df_mother_ID = lin_tree_df_prev.loc[lin_tree_df_ID["parent_ID_tree"]]
if lin_tree_df_ID["parent_ID_tree"] == -1: # make sure that new obj where the parents are not known get skipped
continue
- mother_obj = myutils.get_obj_by_label(prev_rp, lin_tree_df_ID["parent_ID_tree"])
+ mother_obj = prev_rp.get_obj_from_ID(lin_tree_df_ID["parent_ID_tree"])
emerg_frame_i = lin_tree_df_ID["emerg_frame_i"]
isNew = emerg_frame_i == frame_i
@@ -22774,9 +23195,15 @@ def drawObjLin_TreeMothBudLines(self, ax, obj, mother_obj, isNew, ID=None):
return
scatterItem = self.getMothBudLineScatterItem(ax, isNew)
-
- y1, x1 = self.getObjCentroid(obj.centroid)
- y2, x2 = self.getObjCentroid(mother_obj.centroid)
+
+ posData = self.data[self.pos_i]
+ prev_rp = posData.allData_li[posData.frame_i-1]['regionprops']
+ rp = posData.rp
+ if ID is None:
+ ID = obj.label
+ ID_mother = mother_obj.label
+ y1, x1 = self.getObjCentroid(rp.get_centroid(ID))
+ y2, x2 = self.getObjCentroid(prev_rp.get_centroid(ID_mother))
xx, yy = core.get_line(y1, x1, y2, x2, dashed=True)
scatterItem.addPoints(xx, yy)
@@ -22809,69 +23236,16 @@ def getObjOptsSegmLabels(self, obj):
objOpts = self.getObjTextAnnotOpts(obj, 'Draw only IDs', ax=1)
return objOpts
-
- def store_zslices_rp(self, force_update=False):
- if not self.isSegm3D:
- return
-
- posData = self.data[self.pos_i]
- are_zslices_rp_stored = (
- posData.allData_li[posData.frame_i].get('z_slices_rp') is not None
- )
- if force_update or not are_zslices_rp_stored:
- self._update_zslices_rp()
-
- posData.allData_li[posData.frame_i]['z_slices_rp'] = posData.zSlicesRp
def removeObjectFromRp(self, delID):
posData = self.data[self.pos_i]
- rp = []
- IDs = []
- IDs_idxs = {}
- idx = 0
- for obj in posData.rp:
- if obj.label == delID:
- continue
- rp.append(obj)
- IDs.append(obj.label)
- IDs_idxs[obj.label] = idx
- idx += 1
-
- posData.rp = rp
- posData.IDs = IDs
- posData.IDs_idxs = IDs_idxs
-
- if not self.isSegm3D:
- return
-
- zSlicesRp = {}
- for z, zSliceRp in posData.zSlicesRp.items():
- if delID in zSliceRp:
- continue
+ if not isinstance(delID, (list, set, tuple)):
+ delIDs = [delID]
+ else:
+ delIDs = list(delID)
- zSlicesRp[z] = zSlicesRp
-
- posData.zSlicesRp = zSlicesRp
- self.store_zslices_rp(force_update=True)
-
- def get_zslices_rp(self):
- if not self.isSegm3D:
- return
-
- posData = self.data[self.pos_i]
- self.store_zslices_rp()
- posData.zSlicesRp = posData.allData_li[posData.frame_i]['z_slices_rp']
-
- # @exec_time
- def _update_zslices_rp(self):
- if not self.isSegm3D:
- return
-
- posData = self.data[self.pos_i]
- posData.zSlicesRp = {}
- for z, lab2d in enumerate(posData.lab):
- lab2d_rp = skimage.measure.regionprops(lab2d)
- posData.zSlicesRp[z] = {obj.label:obj for obj in lab2d_rp}
+ posData.rp.update_regionprops_via_deletions(set(delIDs))
+ posData.IDs = posData.rp.IDs
def instructHowDeleteID(self):
if 'showInfoDeleteObject' not in self.df_settings.index:
@@ -22914,7 +23288,7 @@ def checkWarnDeletedIDwithEraser(self):
for ID in self.erasedIDs:
if ID == 0:
continue
- if ID in posData.IDs_idxs:
+ if posData.rp.get_obj_from_ID(ID, warn=False) is not None:
continue
self.instructHowDeleteID()
@@ -22928,36 +23302,228 @@ def checkWarnDeletedIDwithEraser(self):
return True
return False
+
+ def _get_entire_depth_axis_from_2D_cutout(self, cutout):
+ # cutout = (xl, xr), (yt, yb), z is always on the y position if depth axis is changed
+ # cutout is in the current view; return grouped ranges in the order
+ # expected by update_rp_get_bbox before conversion to NumPy bbox order.
+ posData = self.data[self.pos_i]
+ if self.isSegm3D:
+ depthAxes = self.switchPlaneCombobox.depthAxes()
+ if depthAxes == 'z':
+ # cutout is (x, y) and we prepend the full z range.
+ z_max = posData.SizeZ
+ return ((0, z_max), cutout[0], cutout[1])
+ if depthAxes == 'y':
+ # cutout is (x, z); convert to (z, x, y).
+ y_max = posData.SizeY
+ return (cutout[1], cutout[0], (0, y_max))
+ elif depthAxes == 'x':
+ # cutout is (y, z); convert to (z, x, y).
+ x_max = posData.SizeX
+ return (cutout[1], (0, x_max), cutout[0])
+ else:
+ return cutout
+
+ def _cutout_to_bbox(self, cutout):
+ """
+ Convert grouped view ranges into a flat bbox in NumPy array order.
+ 2D input: ((x_min, x_max), (y_min, y_max)) → (y_min, x_min, y_max, x_max)
+ 3D input: ((z_min, z_max), (y_min, y_max), (x_min, x_max)) → (z_min, y_min, x_min, z_max, y_max, x_max)
+ """
+ cutout = tuple(
+ (min(r[0], r[1]), max(r[0], r[1])) for r in cutout
+ )
+ if self.isSegm3D:
+ (z_min, z_max), (y_min, y_max), (x_min, x_max) = cutout
+ return (z_min, y_min, x_min, z_max, y_max, x_max)
+ else:
+ (x_min, x_max), (y_min, y_max) = cutout
+ return (y_min, x_min, y_max, x_max)
+
+ def _get_perc_cutout_from_total_img(self, cutout):
+ posData = self.data[self.pos_i]
+ single_timepoint_segm_size = posData.getSingleTimepointSegmSize()
+ if self.isSegm3D:
+ size = (cutout[0][1] - cutout[0][0]) * (cutout[1][1] - cutout[1][0]) * (cutout[2][1] - cutout[2][0])
+ else:
+ size = (cutout[0][1] - cutout[0][0]) * (cutout[1][1] - cutout[1][0])
+ return size / single_timepoint_segm_size
+
+ def update_rp_get_bbox(self, custom_bbox=None, use_bbox=False, use_curr_view=False,
+ specific_IDs=None, add_frac_custom_bbox=0.05):
+ """
+ Returns an expanded bounding box (bbox) for the given IDs or custom_bbox.
+ Returns False if not enough cells or cutout is too large.
+ """
+ posData = self.data[self.pos_i]
+ if len(posData.rp.IDs) < RP_OPT_NUM_CELLS_MIN:
+ return False
+ if not isinstance(specific_IDs, (list, set, np.ndarray)) and specific_IDs is not None:
+ specific_IDs = [specific_IDs]
+ elif specific_IDs is not None and len(specific_IDs) == 0:
+ specific_IDs = None
+
+ # Helper to merge bboxes
+ def merge_bbox(b1, b2):
+ if len(b1) == 4:
+ return (
+ min(b1[0], b2[0]), min(b1[1], b2[1]),
+ max(b1[2], b2[2]), max(b1[3], b2[3])
+ )
+ else:
+ return (
+ min(b1[0], b2[0]), min(b1[1], b2[1]), min(b1[2], b2[2]),
+ max(b1[3], b2[3]), max(b1[4], b2[4]), max(b1[5], b2[5])
+ )
+
+ bbox = None
+ if custom_bbox or use_bbox:
+ if not custom_bbox and use_bbox and specific_IDs is not None:
+ rp_old = posData.rp
+ for ID in specific_IDs:
+ b = rp_old.get_obj_from_ID(ID).bbox
+ bbox = b if bbox is None else merge_bbox(bbox, b)
+ else:
+ bbox = custom_bbox
+
+ if bbox is None:
+ return False
+
+ elif use_curr_view:
+ cutout = self.ax1ViewRange(integers=True)
+ cutout = self._get_entire_depth_axis_from_2D_cutout(cutout)
+ if len(cutout)==2:
+ (xl, xr), (yt, yb) = cutout
+ else:
+ (z1, z2), (xl, xr), (yt, yb) = cutout
+ z_min = min(z1, z2)
+ z_max = max(z1, z2)
+ x_min = min(xl, xr)
+ x_max = max(xl, xr)
+ y_min = min(yt, yb)
+ y_max = max(yt, yb)
+ bbox = (y_min, x_min, y_max, x_max) if len(cutout)==2 else (z_min, y_min, x_min, z_max, y_max, x_max)
+ # Expand bbox by a fraction
+ else:
+ raise ValueError('''Either custom_bbox or use_bbox or use_curr_view must be provided as True''')
+
+ if len(bbox) == 4:
+ y_min, x_min, y_max, x_max = bbox
+ offset_y = int((y_max - y_min) * add_frac_custom_bbox)
+ offset_x = int((x_max - x_min) * add_frac_custom_bbox)
+ offset_y = 1 if offset_y == 0 else offset_y
+ offset_x = 1 if offset_x == 0 else offset_x
+ size_y, size_x = posData.SizeY, posData.SizeX
+ cutout = (
+ (max(0, x_min - offset_x), min(size_x, x_max + offset_x)),
+ (max(0, y_min - offset_y), min(size_y, y_max + offset_y))
+ )
+ else:
+ z_min, y_min, x_min, z_max, y_max, x_max = bbox
+ offset_z = int((z_max - z_min) * add_frac_custom_bbox)
+ offset_y = int((y_max - y_min) * add_frac_custom_bbox)
+ offset_x = int((x_max - x_min) * add_frac_custom_bbox)
+ offset_z = 1 if offset_z == 0 else offset_z
+ offset_y = 1 if offset_y == 0 else offset_y
+ offset_x = 1 if offset_x == 0 else offset_x
+ size_z, size_y, size_x = posData.SizeZ, posData.SizeY, posData.SizeX
+ cutout = (
+ (max(0, z_min - offset_z), min(size_z, z_max + offset_z)),
+ (max(0, y_min - offset_y), min(size_y, y_max + offset_y)),
+ (max(0, x_min - offset_x), min(size_x, x_max + offset_x))
+ )
+
+ perc_from_global = self._get_perc_cutout_from_total_img(cutout)
+ if perc_from_global > RP_OPT_PERC_CUTOUT_MAX:
+ return False
+ return self._cutout_to_bbox(cutout)
@exception_handler
def update_rp(
- self, draw=True, debug=False, update_IDs=True,
- wl_update=True, wl_track_og_curr=False,wl_update_lab=False
+ self, draw=True, debug=False, # og stuff
+ assignments=None, deletionIDs=None, # very quick upates, rp labels are changed but rest is same
+ specific_IDs=None, use_curr_view=False, use_bbox=False, preloaded_bbox=None, # for local updates to PR
+ wl_update=True, wl_track_og_curr=False,wl_update_lab=False, # wl stuff
):
+ """Updates posData.rp
+ Parameters
+ ----------
+
+ """
+ #updating rp is very clostly, as it deletes all the cashed
+ if use_curr_view and use_bbox:
+ raise ValueError('''use_curr_view and use_bbox cannot be True at the
+ same time, as they are mutually exclusive''')
+ local_rp_update = bool(use_curr_view or use_bbox or preloaded_bbox)
posData = self.data[self.pos_i]
# Update rp for current posData.lab (e.g. after any change)
-
if wl_update:
if self.whitelistOriginalIDs is None:
- old_IDs = posData.allData_li[posData.frame_i]['IDs'].copy() # for whitelist stuff
+ old_IDs = posData.allData_li[posData.frame_i]['regionprops'].IDs.copy() # for whitelist stuff
else:
old_IDs = self.whitelistOriginalIDs.copy()
self.whitelistOriginalIDs = None
elif self.whitelistOriginalIDs is None:
- self.whitelist_old_IDs = posData.allData_li[posData.frame_i]['IDs'].copy()
-
- posData.rp = skimage.measure.regionprops(posData.lab)
- if update_IDs:
- IDs = []
- IDs_idxs = {}
- for idx, obj in enumerate(posData.rp):
- IDs.append(obj.label)
- IDs_idxs[obj.label] = idx
- posData.IDs = IDs
- posData.IDs_idxs = IDs_idxs
+ self.whitelist_old_IDs = (
+ posData.allData_li[posData.frame_i]['regionprops'].IDs.copy())
+
+ # check if only one of assignments, deletionIDs or only_current_view is given
+ if sum([assignments is not None,
+ deletionIDs is not None,
+ local_rp_update,
+ ]) > 1:
+ print(assignments is not None, deletionIDs is not None, local_rp_update)
+ raise ValueError('Only one of assignments, deletionIDs, '
+ 'use_curr_view or use_bbox, preloaded_bbox can be used '
+ 'at a time')
+
+ if not isinstance(specific_IDs, (list, set, np.ndarray)) and specific_IDs is not None:
+ specific_IDs = [specific_IDs]
+ elif specific_IDs is not None and len(specific_IDs) == 0:
+ specific_IDs = None
+
+
+ # posData.rp is an acdcRegionprops instance here.
+ # if rp is None (can sometimes happen appearantly???)
+ if posData.rp is None:
+ printl(f'''Warning: posData.rp is None for pos {self.pos_i},
+ frame {posData.frame_i}. Recomputing rp from labels.''')
+
+ posData.rp = regionprops.acdcRegionprops(
+ posData.lab, precache_centroids=False
+ )
+
+ if assignments is not None:
+ # {old_ID: new_ID, ...}
+ posData.rp.update_regionprops_via_assignments(assignments, posData.lab)
+ elif deletionIDs is not None:
+ # (delID1, delID2, ...)
+ posData.rp.update_regionprops_via_deletions(deletionIDs)
+ elif local_rp_update:
+ # first get current view
+ if preloaded_bbox is None:
+ preloaded_bbox = self.update_rp_get_bbox(use_bbox=use_bbox, use_curr_view=use_curr_view,
+ specific_IDs=specific_IDs)
+ if preloaded_bbox is not False:
+ posData.rp.update_regionprops_via_cutout(
+ posData.lab, cutout_bbox=preloaded_bbox, specific_IDs=specific_IDs
+ )
+ # if ID touches border but is not in specific_IDs, it will not be updated,
+ # so be careful!
+ else:
+ posData.rp.update_regionprops(
+ posData.lab
+ )
+ else:
+ posData.rp.update_regionprops(
+ posData.lab,
+ specific_IDs_update_centroids=specific_IDs if preloaded_bbox is not False else None, # since sometimes I preload
+ )
+ posData.IDs = posData.rp.IDs
+
self.update_rp_metadata(draw=draw)
- self.store_zslices_rp(force_update=True)
if not wl_update:
return
@@ -23052,12 +23618,12 @@ def updateTempLayerKeepIDs(self):
def highlightLabelID(self, ID, ax=0):
posData = self.data[self.pos_i]
- try:
- obj = posData.rp[posData.IDs_idxs[ID]]
- except KeyError:
+ obj = posData.rp.get_obj_from_ID(ID, warn=False)
+ if obj is None:
return
- self.textAnnot[ax].highlightObject(obj)
+ self.textAnnot[ax].highlightObject(
+ obj, rp=posData.rp, getObjCentroidFunc=self.getObjCentroid)
def _keepObjects(self, keepIDs=None, lab=None, rp=None):
posData = self.data[self.pos_i]
@@ -23087,7 +23653,7 @@ def removeHighlightLabelID(self, IDs=None, ax=0):
IDs = posData.IDs
for ID in IDs:
- obj = posData.rp[posData.IDs_idxs[ID]]
+ obj = posData.rp.get_obj_from_ID(ID)
self.textAnnot[ax].removeHighlightObject(obj)
def updateKeepIDs(self, IDs):
@@ -23241,7 +23807,9 @@ def applyKeepObjects(self):
elif includeUnvisited:
# Unvisited frame (includeUnvisited = True)
lab = posData.segm_data[i]
- rp = skimage.measure.regionprops(lab)
+ rp = regionprops.acdcRegionprops(
+ lab, precache_centroids=False
+ )
keepLab = self._keepObjects(lab=lab, rp=rp)
posData.segm_data[i] = keepLab
@@ -23323,7 +23891,8 @@ def annotate_rip_and_bin_IDs(self, updateLabel=False):
continue
if obj.excluded:
- y, x = self.getObjCentroid(obj.centroid)
+ ID = obj.label
+ y, x = self.getObjCentroid(posData.rp.get_centroid(ID))
binnedIDs_xx.append(x)
binnedIDs_yy.append(y)
if updateLabel:
@@ -23331,7 +23900,8 @@ def annotate_rip_and_bin_IDs(self, updateLabel=False):
how = self.drawIDsContComboBox.currentText()
if obj.dead:
- y, x = self.getObjCentroid(obj.centroid)
+ ID = obj.label
+ y, x = self.getObjCentroid(posData.rp.get_centroid(ID))
ripIDs_xx.append(x)
ripIDs_yy.append(y)
if updateLabel:
@@ -24011,10 +24581,8 @@ def zoomToObj(self, obj=None):
posData = self.data[self.pos_i]
if obj is None:
ID = self.sender().value()
- try:
- ID_idx = posData.IDs_idxs[ID]
- obj = obj = posData.rp[ID_idx]
- except Exception as e:
+ obj = posData.rp.get_obj_from_ID(ID, warn=False)
+ if obj is None:
self.logger.warning(
f'ID {ID} does not exist (add points by clicking)'
)
@@ -24066,7 +24634,7 @@ def pointsLayerAutoPilot(self, direction):
return
try:
- ID_idx = posData.IDs_idxs[ID]
+ ID_idx = posData.rp.ID_to_idx[ID]
if direction == 'next':
nextID_idx = ID_idx + 1
else:
@@ -24194,7 +24762,7 @@ def checkLoadedTableIds(self, toolbar):
for posData in self.data:
for tableEndName, df in posData.clickEntryPointsDfs.items():
for point_id in df['id'].values:
- if point_id in posData.IDs_idxs:
+ if point_id in posData.rp.IDs:
proceed = self.warnAddingPointWithExistingId(
point_id, table_endname=tableEndName
)
@@ -24428,10 +24996,10 @@ def setHoverCircleAddPoint(self, x, y):
def isPointIdAlreadyNew(self, point_id, action):
posData = self.data[self.pos_i]
- if point_id in posData.IDs_idxs:
+ if point_id in posData.rp.IDs:
return False
- is_ID = point_id in posData.IDs_idxs
+ is_ID = point_id in posData.rp.IDs
pointsDataPos = action.pointsData.get(self.pos_i)
if pointsDataPos is None:
return not is_ID
@@ -24577,6 +25145,8 @@ def getCentroidsPointsData(self, action):
# Centroids (either weighted or not)
# NOTE: if user requested to draw from table we load that in
# apps.AddPointsLayerDialog.ok_cb()
+
+ # this does not have the updated centroid logic to avoid weird behaviours
posData = self.data[self.pos_i]
action.pointsData[self.pos_i] = {posData.frame_i: {}}
if hasattr(action, 'weighingData'):
@@ -25170,30 +25740,42 @@ def initTextAnnot(self, force=False):
Y, X = posData.img_data.shape[-2:]
self.textAnnot[0].initItem((Y, X))
self.textAnnot[1].initItem((Y, X))
-
+
+ def _get_obj_for_current_view_rp(self, obj, posData):
+ # 2D segmentation already has the correct regionprops object.
+ if not self.isSegm3D or not posData.rp.is3D:
+ return obj
+
+ slicing = self.switchPlaneCombobox.depthAxes()
+ zProjHow = self.zProjComboBox.currentText()
+ if zProjHow == 'single z-slice':
+ slice_selector = self.z_lab()
+ slice_number = slice_selector[-1] if isinstance(slice_selector, tuple) else slice_selector
+ obj_current_view = posData.rp.get_obj_from_slice_rp(
+ obj.label, slice_number, slicing=slicing, warn=False
+ )
+ return obj_current_view or obj
+
+ obj_current_view = posData.rp.get_obj_from_proj_rp(
+ obj.label, kind='most_common', slicing=slicing, warn=False
+ )
+ return obj_current_view or obj
+
def getObjContours(
- self, obj, all_external=False, local=False, force_calc=True,
- include_internal=False
+ self, obj, all_external=False, local=False,
+ include_internal=False, rp=None
):
posData = self.data[self.pos_i]
- dataDict = posData.allData_li[posData.frame_i]
- allContours = dataDict.get('contours')
- if allContours is not None and not force_calc:
- z = self.z_lab()
- key = (obj.label, str(z), all_external, local)
- contours = allContours.get(key)
- if contours is not None:
- return contours
+ obj_to_use = self._get_obj_for_current_view_rp(obj, posData)
- obj_image = self.getObjImage(obj.image, obj.bbox).astype(np.uint8)
- obj_bbox = self.getObjBbox(obj.bbox)
try:
contours = core.get_obj_contours(
- obj_image=obj_image,
- obj_bbox=obj_bbox,
+ obj=obj_to_use,
local=local,
- all_external=all_external
+ all_external=all_external,
+ all=include_internal
)
+
except Exception as e:
if all_external:
contours = []
@@ -25202,168 +25784,9 @@ def getObjContours(
self.logger.warning(
f'Object ID {obj.label} contours drawing failed. '
f'(bounding box = {obj.bbox})'
+ f'Error: {e}'
)
return contours
-
- def clearComputedContours(self):
- for posData in self.data:
- for frame_i, dataDict in enumerate(posData.allData_li):
- dataDict['contours'] = {}
-
- def _computeAllContours2D(
- self, dataDict, obj, z, obj_bbox, include_internal=False
- ):
- obj_image = self.getObjImage(obj.image, obj.bbox, z_slice=z)
- if obj_image is None:
- return
-
- all_external = False
- local = False
- contours = core.get_obj_contours(
- obj_image=obj_image,
- obj_bbox=obj_bbox,
- local=local,
- all_external=all_external
- )
- key = (obj.label, str(z), all_external, local)
- dataDict['contours'][key] = contours
-
- all_external = True
- local = False
- contours = core.get_obj_contours(
- obj_image=obj_image,
- obj_bbox=obj_bbox,
- local=local,
- all_external=all_external,
- all=include_internal
- )
- key = (obj.label, str(z), all_external, local)
- dataDict['contours'][key] = contours
-
- return dataDict
-
- def computeAllContours(self):
- self.logger.info('Computing all contours...')
- posData = self.data[self.pos_i]
- zz = [None]
- if self.isSegm3D:
- zz.extend(range(posData.SizeZ))
-
- include_internal = self.showAllContoursToggle.isChecked()
- for frame_i, dataDict in enumerate(posData.allData_li):
- lab = dataDict['labels']
- if lab is None:
- break
-
- rp = dataDict['regionprops']
- if rp is None:
- rp = skimage.measure.regionprops(lab)
-
- dataDict['contours'] = {}
- for obj in rp:
- obj_bbox = self.getObjBbox(obj.bbox)
- for z in zz:
- if not self.isObjVisible(obj.bbox, z_slice=z):
- continue
-
- try:
- self._computeAllContours2D(
- dataDict, obj, z, obj_bbox,
- include_internal=include_internal
- )
- except Exception as err:
- # Contours computation fails on weird objects
- pass
-
- def computeAllObjToObjCostPairs(self):
- desc = (
- 'Computing all object-to-object cost matrices...'
- )
- self.logger.info(desc)
- posData = self.data[self.pos_i]
-
-
- self.progressWin = apps.QDialogWorkerProgress(
- title=desc, parent=self, pbarDesc=desc
- )
- self.progressWin.mainPbar.setMaximum(0)
- self.progressWin.show(self.app)
-
- self.computeAllObjCostPairsThread = QThread()
- self.computeAllObjCostPairsWorker = workers.SimpleWorker(
- posData, self._computeAllObjToObjCostPairs
- )
-
- self.computeAllObjCostPairsWorker.moveToThread(
- self.computeAllObjCostPairsThread
- )
-
- self.computeAllObjCostPairsWorker.signals.finished.connect(
- self.computeAllObjCostPairsThread.quit
- )
- self.computeAllObjCostPairsWorker.signals.finished.connect(
- self.computeAllObjCostPairsWorker.deleteLater
- )
- self.computeAllObjCostPairsThread.finished.connect(
- self.computeAllObjCostPairsThread.deleteLater
- )
-
- self.computeAllObjCostPairsWorker.signals.critical.connect(
- self.computeAllObjCostPairsWorkerCritical
- )
- self.computeAllObjCostPairsWorker.signals.initProgressBar.connect(
- self.workerInitProgressbar
- )
- self.computeAllObjCostPairsWorker.signals.progressBar.connect(
- self.workerUpdateProgressbar
- )
- self.computeAllObjCostPairsWorker.signals.progress.connect(
- self.workerProgress
- )
- self.computeAllObjCostPairsWorker.signals.finished.connect(
- self.computeAllObjCostPairsWorkerFinished
- )
-
- self.computeAllObjCostPairsThread.started.connect(
- self.computeAllObjCostPairsWorker.run
- )
- self.computeAllObjCostPairsThread.start()
-
- self.computeAllObjCostPairsWorkerLoop = QEventLoop()
- self.computeAllObjCostPairsWorkerLoop.exec_()
-
- def _computeAllObjToObjCostPairs(self, posData):
- self.computeAllObjCostPairsWorker.signals.initProgressBar.emit(
- len(posData.allData_li)
- )
- for frame_i, dataDict in enumerate(posData.allData_li):
- if frame_i == 0:
- continue
-
- rp = dataDict['regionprops']
- if rp is None:
- break
-
- prev_rp = posData.allData_li[frame_i-1]['regionprops']
- dist_matrix = core._compute_all_obj_to_obj_contour_dist_pairs(
- dataDict['contours'], rp,
- prev_rp=prev_rp,
- restrict_search=True
- )
- dataDict['obj_to_obj_dist_cost_matrix_df'] = dist_matrix
- self.computeAllObjCostPairsWorker.signals.progressBar.emit(1)
- self.computeAllObjCostPairsWorker.signals.initProgressBar.emit(0)
-
- def computeAllObjCostPairsWorkerCritical(self, error):
- self.computeAllObjCostPairsWorkerLoop.exit()
- self.workerCritical(error)
-
- def computeAllObjCostPairsWorkerFinished(self, output):
- if self.progressWin is not None:
- self.progressWin.workerFinished = True
- self.progressWin.close()
- self.progressWin = None
- self.computeAllObjCostPairsWorkerLoop.exit()
def setOverlaySegmMasks(self, force=False, forceIfNotActive=False):
if not hasattr(self, 'currentLab2D'):
@@ -25396,27 +25819,21 @@ def setOverlaySegmMasks(self, force=False, forceIfNotActive=False):
self.extendLabelsLUT(maxID+10)
currentLab2D = self.currentLab2D
+ if (
+ self.isSegm3D
+ and self.zProjComboBox.currentText() != 'single z-slice'
+ and getattr(posData, 'rp', None) is not None
+ and posData.rp.is3D
+ ):
+ slicing = self.switchPlaneCombobox.depthAxes()
+ currentLab2D = posData.rp.get_projection_lab_sorted(slicing=slicing)
+ self.currentLab2D = currentLab2D
+
if isOverlaySegmLeftActive:
self.labelsLayerImg1.setImage(currentLab2D, autoLevels=False)
if isOverlaySegmRightActive:
self.labelsLayerRightImg.setImage(currentLab2D, autoLevels=False)
-
- def getObject2DimageFromZ(self, z, obj):
- posData = self.data[self.pos_i]
- z_min = obj.bbox[0]
- local_z = z - z_min
- if local_z >= posData.SizeZ or local_z < 0:
- return
- return obj.image[local_z]
-
- def getObject2DsliceFromZ(self, z, obj):
- posData = self.data[self.pos_i]
- z_min = obj.bbox[0]
- local_z = z - z_min
- if local_z >= posData.SizeZ or local_z < 0:
- return
- return obj.image[local_z]
def isObjVisible(self, obj_bbox, debug=False, z_slice=None):
if z_slice is None:
@@ -25457,13 +25874,12 @@ def getObjImage(self, obj_image, obj_bbox, z_slice=None):
# required a projection
return obj_image.max(axis=0)
- min_z = obj_bbox[0]
if z_slice is None:
z_slice = self.z_lab()
if isinstance(z_slice, tuple):
z_slice = z_slice[-1]
- local_z = z_slice - min_z
+ local_z = z_slice - obj_bbox[0]
try:
obi_image_2d = obj_image[local_z]
except Exception as err:
@@ -26338,6 +26754,7 @@ def applyDelROIs(self):
self.get_data()
def initTempLayerBrush(self, ID, ax=0):
+ posData = self.data[self.pos_i]
if ax == 0:
how = self.drawIDsContComboBox.currentText()
else:
@@ -26347,7 +26764,24 @@ def initTempLayerBrush(self, ID, ax=0):
Y, X = self.img1.image.shape[:2]
tempImage = np.zeros((Y, X), dtype=np.uint32)
if how.find('contours') != -1:
- tempImage[self.currentLab2D==ID] = ID
+ # Keep the currently edited object visible while painting.
+ obj = None
+ rp = getattr(posData, 'rp', None)
+ if rp is not None:
+ obj = rp.get_obj_from_ID(ID, warn=False)
+ if obj is not None:
+ obj = self._get_obj_for_current_view_rp(obj, posData)
+
+ if (
+ obj is not None
+ and hasattr(obj, 'slice')
+ and hasattr(obj, 'image')
+ and len(obj.slice) == 2
+ and obj.image.ndim == 2
+ ):
+ tempImage[obj.slice][obj.image] = ID
+ else:
+ tempImage[self.currentLab2D == ID] = ID
self.brushImage = tempImage.copy()
self.brushContourImage = np.zeros((Y, X, 4), dtype=np.uint8)
color = self.imgGrad.contoursColorButton.color()
@@ -26383,6 +26817,7 @@ def setTempBrushMaskFromWand(self, flood_mask, init=False):
# @exec_time
def setTempImg1Brush(self, init: bool, mask, ID, toLocalSlice=None, ax=0):
+ posData = self.data[self.pos_i]
if init:
self.initTempLayerBrush(ID, ax=ax)
@@ -26397,14 +26832,11 @@ def setTempImg1Brush(self, init: bool, mask, ID, toLocalSlice=None, ax=0):
brushImage[toLocalSlice][mask] = ID
if self.annotContourCheckbox.isChecked():
- try:
- obj = skimage.measure.regionprops(brushImage)[0]
- except IndexError:
- return
- objContour = [self.getObjContours(obj)]
- # objContour = core.get_obj_contours(
- # obj_image=(brushImage>0).astype(np.uint8), local=True
- # )
+ brushMask = np.ascontiguousarray((brushImage > 0), dtype=np.uint8)
+
+ objContour = core.get_obj_contours(
+ obj_image=brushMask, obj_bbox=None, all_external=True, local=True
+ )
self.brushContourImage[:] = 0
img = self.brushContourImage
color = self.brushContoursRgba
@@ -26447,9 +26879,22 @@ def setTempImg1Eraser(self, mask, init=False, toLocalSlice=None, ax=0):
self.clearObjFromMask(
self.contoursImage, mask, toLocalSlice=toLocalSlice
)
- erasedRp = skimage.measure.regionprops(self.erasedLab)
- for obj in erasedRp:
- self.addObjContourToContoursImage(obj=obj, ax=ax)
+ thickness = self.contLineWeight
+ color = self.contLineColor
+ for erasedID in np.unique(self.erasedLab):
+ if erasedID == 0:
+ continue
+ erasedMask = np.ascontiguousarray(
+ (self.erasedLab == erasedID), dtype=np.uint8
+ )
+ contours = core.get_obj_contours(
+ obj_image=erasedMask, obj_bbox=None,
+ all_external=True, local=True
+ )
+ cv2.drawContours(self.contoursImage, contours, -1, color, thickness)
+ imageItem = self.getContoursImageItem(ax)
+ if imageItem is not None:
+ imageItem.setImage(self.contoursImage)
elif how.find('overlay segm. masks') != -1:
labelsImage = self.getLabelsLayerImage(ax=ax)
self.clearObjFromMask(labelsImage, mask, toLocalSlice=toLocalSlice)
@@ -26462,13 +26907,13 @@ def setTempImg1Eraser(self, mask, init=False, toLocalSlice=None, ax=0):
self.labelsLayerRightImg.image, autoLevels=False
)
- def _setTempImgExpandLabelSegmMasks(self, prevCoords, ax=0):
+ def _setTempImgExpandLabelSegmMasks(self, prevCoords, expandedObjCoords, ax=0):
# Remove previous overlaid mask
labelsImage = self.getLabelsLayerImage(ax=ax)
labelsImage[prevCoords] = 0
- # Overlay new moved mask
- labelsImage[prevCoords] = self.expandingID
+ # Overlay new expanded mask
+ labelsImage[expandedObjCoords] = self.expandingID
if ax == 0:
self.labelsLayerImg1.setImage(
@@ -26479,25 +26924,31 @@ def _setTempImgExpandLabelSegmMasks(self, prevCoords, ax=0):
def _setTempImgExpandLabelContours(self, prevCoords, ax=0):
self.contoursImage[prevCoords] = [0,0,0,0]
- currentLab2Drp = skimage.measure.regionprops(self.currentLab2D)
- for obj in currentLab2Drp:
- if obj.label == self.expandingID:
- # self.clearObjContour(obj=obj, ax=ax)
- self.addObjContourToContoursImage(obj=obj, ax=ax, force=True)
- break
+ expandMask = np.ascontiguousarray(
+ (self.currentLab2D == self.expandingID), dtype=np.uint8
+ )
+ contours = core.get_obj_contours(
+ obj_image=expandMask, obj_bbox=None, all_external=True, local=True
+ )
+ imageItem = self.getContoursImageItem(ax, force=True)
+ if imageItem is not None:
+ thickness = self.contLineWeight
+ color = self.contLineColor
+ cv2.drawContours(self.contoursImage, contours, -1, color, thickness)
+ imageItem.setImage(self.contoursImage)
def setTempImgExpandLabel(self, prevCoords, expandedObjCoords, ax=0):
if ax == 0:
how = self.drawIDsContComboBox.currentText()
else:
how = self.getAnnotateHowRightImage()
-
- self._setTempImgExpandLabelContours(prevCoords, ax=ax)
-
- # if how.find('overlay segm. masks') != -1:
- # self._setTempImgExpandLabelSegmMasks(ax=ax)
- # else:
- # self._setTempImgExpandLabelContours(ax=ax)
+
+ if how.find('overlay segm. masks') != -1:
+ self._setTempImgExpandLabelSegmMasks(
+ prevCoords, expandedObjCoords, ax=ax
+ )
+ else:
+ self._setTempImgExpandLabelContours(prevCoords, ax=ax)
def setTempImg1MoveLabel(self, ax=0):
if ax == 0:
@@ -26506,11 +26957,18 @@ def setTempImg1MoveLabel(self, ax=0):
how = self.getAnnotateHowRightImage()
if how.find('contours') != -1:
- currentLab2Drp = skimage.measure.regionprops(self.currentLab2D)
- for obj in currentLab2Drp:
- if obj.label == self.movingID:
- self.addObjContourToContoursImage(obj=obj, ax=ax)
- break
+ moveMask = np.ascontiguousarray(
+ (self.currentLab2D == self.movingID), dtype=np.uint8
+ )
+ contours = core.get_obj_contours(
+ obj_image=moveMask, obj_bbox=None, all_external=True, local=True
+ )
+ imageItem = self.getContoursImageItem(ax)
+ if imageItem is not None:
+ thickness = self.contLineWeight
+ color = self.contLineColor
+ cv2.drawContours(self.contoursImage, contours, -1, color, thickness)
+ imageItem.setImage(self.contoursImage)
elif how.find('overlay segm. masks') != -1:
if ax == 0:
self.labelsLayerImg1.setImage(self.currentLab2D, autoLevels=False)
@@ -26611,7 +27069,8 @@ def updateCcaDfDeletedIDsTimelapse(
else:
for delID in deletedIDs:
dataDict = posData.allData_li[fut_frame_i]
- delIDexists = dataDict['IDs_idxs'].get(delID, False)
+ rp = dataDict['regionprops']
+ delIDexists = delID in rp.IDs
if not delIDexists:
continue
@@ -26648,7 +27107,8 @@ def updateCcaDfDeletedIDsTimelapse(
else:
for delID in deletedIDs:
dataDict = posData.allData_li[past_frame_i]
- delIDexists = dataDict['IDs_idxs'].get(delID, False)
+ rp = dataDict['regionprops']
+ delIDexists = delID in rp.IDs
if not delIDexists:
continue
@@ -26944,8 +27404,7 @@ def highlightHoverID(self, x, y, hoverID=None):
return
posData = self.data[self.pos_i]
- objIdx = posData.IDs_idxs[hoverID]
- obj = posData.rp[objIdx]
+ obj = posData.rp.get_obj_from_ID(hoverID)
self.goToZsliceSearchedID(obj)
self.highlightSearchedID(hoverID)
@@ -27009,12 +27468,10 @@ def highlightHoverIDsKeptObj(self, x, y, hoverID=None):
return
posData = self.data[self.pos_i]
- try:
- objIdx = posData.IDs_idxs[hoverID]
- except KeyError as err:
- return
+ obj = posData.rp.get_obj_from_ID(hoverID, warn=False)
+ if obj is None:
+ return
- obj = posData.rp[objIdx]
self.goToZsliceSearchedID(obj)
for ID in self.keptObjectsIDs:
@@ -27079,11 +27536,10 @@ def highlightSearchedID(self, ID, force=False, greyOthers=True):
self.highlightedID = ID
self.highlightIDToolbar.setVisible(True)
- objIdx = posData.IDs_idxs.get(ID)
- if objIdx is None:
+ obj = posData.rp.get_obj_from_ID(ID, warn=False)
+ if obj is None:
return
- obj = posData.rp[objIdx]
isObjVisible = self.isObjVisible(obj.bbox)
if not isObjVisible:
return
@@ -27111,7 +27567,11 @@ def highlightSearchedID(self, ID, force=False, greyOthers=True):
self.highLightIDLayerImg1.setImage(self.highlightedLab)
self.labelsLayerImg1.setOpacity(alpha/3)
else:
- contours = self.getObjContours(obj, all_external=True)
+ contours = self.getObjContours(
+ obj,
+ all_external=True,
+ include_internal=self.showAllContoursToggle.isChecked()
+ )
for cont in contours:
self.searchedIDitemLeft.addPoints(cont[:,0]+0.5, cont[:,1]+0.5)
@@ -27121,7 +27581,11 @@ def highlightSearchedID(self, ID, force=False, greyOthers=True):
self.labelsLayerRightImg.setOpacity(alpha/3)
else:
if contours is None:
- contours = self.getObjContours(obj, all_external=True)
+ contours = self.getObjContours(
+ obj,
+ all_external=True,
+ include_internal=self.showAllContoursToggle.isChecked()
+ )
for cont in contours:
self.searchedIDitemRight.addPoints(cont[:,0]+0.5, cont[:,1]+0.5)
@@ -27314,8 +27778,14 @@ def setManualBackgroundImage(self):
self.initManualBackgroundImage()
contours = []
- for obj in skimage.measure.regionprops(posData.manualBackgroundLab):
- obj_contours = self.getObjContours(obj, all_external=True)
+ for obj in regionprops.acdcRegionprops(
+ posData.manualBackgroundLab, precache_centroids=False
+ ):
+ obj_contours = self.getObjContours(
+ obj,
+ all_external=True,
+ include_internal=self.showAllContoursToggle.isChecked()
+ )
contours.extend(obj_contours)
textItem = self.manualBackgroundTextItems[obj.label]
textItem.setText(f'{obj.label}')
@@ -27331,7 +27801,7 @@ def setManualBackgroundImage(self):
def setManualBackgrounNextID(self):
posData = self.data[self.pos_i]
currentID = self.manualBackgroundObj.label
- idx = posData.IDs_idxs[currentID]
+ idx = posData.rp.ID_to_idx[currentID]
next_idx = idx + 1
if next_idx >= len(posData.IDs):
return
@@ -27390,7 +27860,9 @@ def setManualBackgroundLab(self, load_from_store=False, debug=True):
if posData.manualBackgroundLab is None:
self.initManualBackgroundImage()
- for obj in skimage.measure.regionprops(posData.manualBackgroundLab):
+ for obj in regionprops.acdcRegionprops(
+ posData.manualBackgroundLab, precache_centroids=False
+ ):
textItem = pg.TextItem(text='', color='r', anchor=(0.5, 0.5))
if obj.label in self.manualBackgroundTextItems:
continue
@@ -27407,13 +27879,26 @@ def updateContoursImage(self, ax, delROIsIDs=None, compute=True):
self.contoursImage[:] = 0
contours = []
- for obj in skimage.measure.regionprops(self.currentLab2D):
+ posData = self.data[self.pos_i]
+ rp = posData.rp
+ use_local_rp = (
+ self.isSegm3D
+ and self.zProjComboBox.currentText() == 'single z-slice'
+ )
+ if rp is None:
+ lab = self.currentLab2D
+ rp = regionprops.acdcRegionprops(lab, precache_centroids=False)
+ if not use_local_rp:
+ posData.rp = rp
+ elif use_local_rp and rp.is3D:
+ rp = self.rpCurr2D()
+
+ for obj in rp:
obj_contours = self.getObjContours(
- obj,
- all_external=True,
- force_calc=compute,
+ obj,
+ all_external=True,
include_internal=self.showAllContoursToggle.isChecked()
- )
+ )
contours.extend(obj_contours)
thickness = self.contLineWeight
@@ -27426,17 +27911,18 @@ def setContoursImage(self, imageItem, contours, thickness, color):
def getObjFromID(self, ID):
posData = self.data[self.pos_i]
- try:
- idx = posData.IDs_idxs[ID]
- except KeyError as e:
+ obj = posData.rp.get_obj_from_ID(ID, warn=False)
+ if obj is None:
# Object already cleared
return
-
- obj = posData.rp[idx]
return obj
def setLostObjectContour(self, obj):
- allContours = self.getObjContours(obj, all_external=True)
+ allContours = self.getObjContours(
+ obj,
+ all_external=True,
+ include_internal=self.showAllContoursToggle.isChecked()
+ )
for objContours in allContours:
xx = objContours[:,0] + 0.5
yy = objContours[:,1] + 0.5
@@ -27448,7 +27934,11 @@ def setTrackedLostObjectContour(self, obj):
if self.isExportingVideo:
return
- allContours = self.getObjContours(obj, all_external=True)
+ allContours = self.getObjContours(
+ obj,
+ all_external=True,
+ include_internal=self.showAllContoursToggle.isChecked()
+ )
for objContours in allContours:
xx = objContours[:,0] + 0.5
yy = objContours[:,1] + 0.5
@@ -27472,7 +27962,6 @@ def updateLostContoursImage(self, ax, draw=True, delROIsIDs=None):
posData = self.data[self.pos_i]
prev_rp = posData.allData_li[posData.frame_i-1]['regionprops']
- prev_IDs_idxs = posData.allData_li[posData.frame_i-1]['IDs_idxs']
if posData.whitelist is not None and posData.whitelist.whitelistIDs is not None:
whitelist = posData.whitelist.whitelistIDs[posData.frame_i-1]
else:
@@ -27483,11 +27972,15 @@ def updateLostContoursImage(self, ax, draw=True, delROIsIDs=None):
if lostID in delROIsIDs or (whitelist is not None and lostID not in whitelist):
continue
- obj = prev_rp[prev_IDs_idxs[lostID]]
+ obj = prev_rp.get_obj_from_ID(lostID)
if not self.isObjVisible(obj.bbox):
continue
- obj_contours = self.getObjContours(obj, all_external=True)
+ obj_contours = self.getObjContours(
+ obj,
+ all_external=True,
+ include_internal=self.showAllContoursToggle.isChecked()
+ )
if ax == 0:
self.addLostObjsToLostObjImage(obj, lostID)
@@ -27528,17 +28021,20 @@ def updateLostTrackedContoursImage(
tracked_lost_IDs = self.getTrackedLostIDs()
prev_rp = posData.allData_li[posData.frame_i-1]['regionprops']
- prev_IDs_idxs = posData.allData_li[posData.frame_i-1]['IDs_idxs']
contours = []
for tracked_lost_ID in tracked_lost_IDs:
if tracked_lost_ID in delROIsIDs:
continue
- obj = prev_rp[prev_IDs_idxs[tracked_lost_ID]]
+ obj = prev_rp.get_obj_from_ID(tracked_lost_ID)
if not self.isObjVisible(obj.bbox):
continue
- obj_contours = self.getObjContours(obj, all_external=True)
+ obj_contours = self.getObjContours(
+ obj,
+ all_external=True,
+ include_internal=self.showAllContoursToggle.isChecked()
+ )
contours.extend(obj_contours)
self.drawLostTrackedObjContoursImage(imageItem, contours)
@@ -27596,13 +28092,20 @@ def getNearestLostObjID(self, y, x):
return nearest_ID
def setCcaIssueContour(self, obj):
- objContours = self.getObjContours(obj, all_external=True)
+ objContours = self.getObjContours(
+ obj,
+ all_external=True,
+ include_internal=self.showAllContoursToggle.isChecked()
+ )
for cont in objContours:
xx = cont[:,0] + 0.5
yy = cont[:,1] + 0.5
self.ax1_lostObjScatterItem.addPoints(xx, yy)
+
+ posData = self.data[self.pos_i]
self.textAnnot[0].addObjAnnotation(
- obj, 'lost_object', f'{obj.label}?', False
+ obj, 'lost_object', f'{obj.label}?', False,
+ rp=posData.rp, getObjCentroidFunc=self.getObjCentroid
)
def isLastVisitedAgainCca(self, curr_df, enforceAll=False):
@@ -27650,7 +28153,8 @@ def highlightNewCellNotEnoughG1cells(self, IDsCellsG1):
yy = objContours[:,1] + 0.5
self.ccaFailedScatterItem.addPoints(xx, yy)
self.textAnnot[0].addObjAnnotation(
- obj, 'green', f'{obj.label}?', False
+ obj, 'green', f'{obj.label}?', False,
+ rp=posData.rp, getObjCentroidFunc=self.getObjCentroid
)
def handleNoCellsInG1(self, numCellsG1, numNewCells):
@@ -27688,7 +28192,11 @@ def addObjContourToContoursImage(
if obj is None:
return
- contours = self.getObjContours(obj, all_external=True)
+ contours = self.getObjContours(
+ obj,
+ all_external=True,
+ include_internal=self.showAllContoursToggle.isChecked()
+ )
if thickness is None:
thickness = self.contLineWeight
if color is None:
@@ -27747,9 +28255,7 @@ def setAllTextAnnotations(self, labelsToSkip=None):
return delROIsIDs
def setAllContoursImages(self, delROIsIDs=None, compute=True):
- if compute:
- self.computeAllContours()
- self.updateContoursImage(ax=0, delROIsIDs=delROIsIDs, compute=compute)
+ self.updateContoursImage(ax=0, delROIsIDs=delROIsIDs, compute=compute) #almost all from here
self.updateContoursImage(ax=1, delROIsIDs=delROIsIDs, compute=compute)
def setAllLostObjContoursImage(self, delROIsIDs=None):
@@ -27875,7 +28381,6 @@ def keyDownCallback(
QAbstractSlider.SliderAction.SliderSingleStepSub
)
- # @exec_time
@exception_handler
def updateAllImages(
self, image=None, computePointsLayers=True, computeContours=True,
@@ -27957,22 +28462,19 @@ def deleteIDFromLab(
lab = self.get_2Dlab(lab)
if delMask is not None:
delMask = self.get_2Dlab(delMask)
- rp = skimage.measure.regionprops(lab)
- IDs_idxs = {obj.label: idx for idx, obj in enumerate(rp)}
- else:
+ rp = regionprops.acdcRegionprops(lab, precache_centroids=False)
+ else:
if frame_i==posData.frame_i:
rp = posData.rp
- IDs_idxs = posData.IDs_idxs
else:
rp = posData.allData_li[frame_i]['regionprops']
- IDs_idxs = posData.allData_li[frame_i]['IDs_idxs']
if isinstance(delID, int):
delID = [delID]
is_any_id_present = False
for _delID in delID:
- if _delID in IDs_idxs:
+ if _delID in rp.IDs:
is_any_id_present = True
break
@@ -27985,10 +28487,9 @@ def deleteIDFromLab(
delMask[:] = False
for _delID in delID:
- idx = IDs_idxs.get(_delID, None)
- if idx is None:
+ if _delID not in rp.IDs:
continue
- obj = rp[idx]
+ obj = rp.get_obj_from_ID(_delID)
delMask[obj.slice][obj.image] = True
lab[delMask] = 0
@@ -28001,31 +28502,6 @@ def deleteIDFromLab(
return lab, delMask
- def removeStoredContours(self, delID, frame_i=None, z_slice=None):
- posData = self.data[self.pos_i]
-
- if frame_i is None:
- frame_i = posData.frame_i
-
- dataDict = posData.allData_li[posData.frame_i]
- try:
- newContours = {}
- for key, contours in dataDict['contours'].items():
- ID = key[0]
- if ID == delID:
- continue
-
- if z_slice is not None:
- z_slice_i = key[1]
- if z_slice_i != z_slice:
- continue
-
- newContours[key] = contours
-
- dataDict['contours'] = newContours
- except KeyError as err:
- pass
-
@disableWindow
def deleteIDmiddleClick(
self, delIDs: Iterable, applyFutFrames, includeUnvisited,
@@ -28084,7 +28560,6 @@ def deleteIDmiddleClick(
self.clearObjContour(ID=_delID, ax=1)
if z_slice is None:
self.removeObjectFromRp(_delID)
- self.removeStoredContours(_delID, z_slice=z_slice)
if shift and self.isSegm3D:
self.update_rp()
@@ -28129,9 +28604,13 @@ def setOverlayLabelsItems(self, specific=None):
imageItem, contoursItem, gradItem = items
contoursItem.clear()
if drawMode == 'Draw contours':
- for obj in skimage.measure.regionprops(ol_lab):
+ for obj in regionprops.acdcRegionprops(
+ ol_lab, precache_centroids=False
+ ):
contours = self.getObjContours(
- obj, all_external=True
+ obj,
+ all_external=True,
+ include_internal=self.showAllContoursToggle.isChecked()
)
for cont in contours:
contoursItem.addPoints(cont[:,0]+0.5, cont[:,1]+0.5)
@@ -28294,9 +28773,12 @@ def highlightHoverLostObj(self, modifiers, event):
self.ax1_lostObjScatterItem.setData([], [])
else:
prev_rp = posData.allData_li[posData.frame_i-1]['regionprops']
- prev_IDs_idxs = posData.allData_li[posData.frame_i-1]['IDs_idxs']
- lostObj = prev_rp[prev_IDs_idxs[hoverLostID]]
- obj_contours = self.getObjContours(lostObj, all_external=True)
+ lostObj = prev_rp.get_obj_from_ID(hoverLostID)
+ obj_contours = self.getObjContours(
+ lostObj,
+ all_external=True,
+ include_internal=self.showAllContoursToggle.isChecked()
+ )
for cont in obj_contours:
xx = cont[:,0]
yy = cont[:,1]
@@ -28316,10 +28798,12 @@ def getPrevFrameIDs(self, current_frame_i=None):
if current_frame_i is None:
return []
- prev_frame_i = current_frame_i - 1
- prevIDs = posData.allData_li[prev_frame_i]['IDs']
+ if current_frame_i == 0:
+ return []
- if prevIDs:
+ prev_frame_i = current_frame_i - 1
+ if posData.allData_li[prev_frame_i]['regionprops'] is not None:
+ prevIDs = posData.allData_li[prev_frame_i]['regionprops'].IDs
return prevIDs
# IDs in previous frame were not stored --> load prev lab from HDD
@@ -28328,8 +28812,11 @@ def getPrevFrameIDs(self, current_frame_i=None):
frame_i=prev_frame_i,
return_copy=False
)
- rp = skimage.measure.regionprops(prev_lab)
- prevIDs = [obj.label for obj in rp]
+ rp = regionprops.acdcRegionprops(
+ prev_lab, precache_centroids=False
+ )
+ posData.allData_li[prev_frame_i]['regionprops'] = rp
+ prevIDs = rp.IDs
return prevIDs
# @exec_time
@@ -28490,7 +28977,9 @@ def separateByLabelling(self, lab, rp, maxID=None):
maxID = max(posData.IDs, default=1)
for obj in rp:
lab_obj = skimage.measure.label(obj.image)
- rp_lab_obj = skimage.measure.regionprops(lab_obj)
+ rp_lab_obj = regionprops.acdcRegionprops(
+ lab_obj, precache_centroids=False
+ )
if len(rp_lab_obj)<=1:
continue
lab_obj += maxID
@@ -28544,105 +29033,86 @@ def trackManuallyAddedObject(
added_IDs = [added_IDs]
posData = self.data[self.pos_i]
- tracked_lab = self.tracking(
+ tracked_lab, assignments = self.tracking(
enforce=True, assign_unique_new_IDs=False, return_lab=True,
- IDs=added_IDs
+ specific_IDs=added_IDs, return_assignments=True,
+ against_next=posData.frame_i==0
)
+
+ # RP not updated after tracking!!!
self.clearAssignedObjsSecondStep()
if tracked_lab is None:
return
# Track only new object
- prevIDs = posData.allData_li[posData.frame_i-1]['IDs']
-
- # mask = np.zeros(posData.lab.shape, dtype=bool)
- update_rp = False
-
+ prevIDs = posData.allData_li[posData.frame_i-1]['regionprops'].IDs
+
+ # assignments_new = dict()
+ # self.update_rp(assignments=assignments)
for added_ID in added_IDs:
- # try:
- # obj = posData.rp[added_ID] # ID not present
- # mask[obj.slice][obj.image] = True
-
- # except IndexError as err:
- mask = posData.lab == added_ID
+
+ # check if added ID is already present
+ # here PR is "stale" so ID maps are not tracked
+ obj = posData.rp.get_obj_from_ID(added_ID, warn=False)
+ if obj is None:
+ continue
try:
- trackedID = tracked_lab[mask][0]
+ trackedID = tracked_lab[obj.slice][obj.image][0]
except IndexError as err:
# added_ID is not present
continue
isTrackedIDalreadyPresentAndNotNew = (
- posData.IDs_idxs.get(trackedID) is not None
+ posData.rp.ID_to_idx.get(trackedID) is not None
and added_ID != trackedID
)
if isTrackedIDalreadyPresentAndNotNew:
+ self.updatePointsLayerClickEntryTableEndname(
+ 'added obj already present', added_ID, trackedID
+ )
continue
isTrackedIDinPrevIDs = trackedID in prevIDs
if isTrackedIDinPrevIDs:
- posData.lab[mask] = trackedID
+ posData.lab[obj.slice][obj.image] = trackedID
else:
# New object where we can try to track against next frame
- trackedID = self.trackNewIDtoNewIDsFutureFrame(added_ID, mask)
+ trackedID, assignments = self.trackNewIDtoNewIDsFutureFrame(added_ID, obj, assignments)
if trackedID is None:
self.clearAssignedObjsSecondStep()
continue
- posData.lab[mask] = trackedID
+ posData.lab[obj.slice][obj.image] = trackedID
self.keepOnlyNewIDAssignedObjsSecondStep(trackedID)
- update_rp = True
- if update_rp:
- self.update_rp(wl_update=wl_update)
-
+ self.update_rp(wl_update=wl_update, assignments=assignments)
+
def trackFrameCustomTracker(
- self, prev_lab, currentLab, IDs=None, unique_ID=None
+ self, prev_lab, currentLab, specific_IDs=None, unique_ID=None,
+ return_assignments=True, dont_return_tracked_lab=False
):
if unique_ID is None:
unique_ID = self.setBrushID()
- try:
- tracked_result = self.realTimeTracker.track_frame(
- prev_lab, currentLab,
- unique_ID=unique_ID,
- IDs=IDs,
- **self.track_frame_params,
- )
- except TypeError as err:
- if str(err).find('an unexpected keyword argument \'unique_ID\'') != -1:
- try:
- tracked_result = self.realTimeTracker.track_frame(
- prev_lab, currentLab, IDs=IDs,
- **self.track_frame_params
- )
- except TypeError as err:
- if str(err).find('an unexpected keyword argument \'IDs\'') != -1:
- tracked_result = self.realTimeTracker.track_frame(
- prev_lab, currentLab,
- **self.track_frame_params)
- else:
- raise err
- elif str(err).find('an unexpected keyword argument \'IDs\'') != -1:
- try:
- tracked_result = self.realTimeTracker.track_frame(
- prev_lab, currentLab,
- unique_ID=unique_ID,
- **self.track_frame_params
- )
- except TypeError as err:
- if str(err).find('an unexpected keyword argument \'unique_ID\'') != -1:
- tracked_result = self.realTimeTracker.track_frame(
- prev_lab, currentLab,
- **self.track_frame_params
- )
- else:
- raise err
- else:
- raise err
+
+ kwargs_total = {
+ 'unique_ID': unique_ID,
+ 'return_assignments': return_assignments,
+ 'dont_return_tracked_lab': dont_return_tracked_lab,
+ 'specific_IDs': specific_IDs
+ }
+ kwargs_total.update(self.track_frame_params)
+
+ kwargs = {k: v for k, v in kwargs_total.items() if k in self.realTimeTracker_kwargs}
+ tracked_result = self.realTimeTracker.track_frame(
+ prev_lab, currentLab,
+ **kwargs,
+ )
return tracked_result
def trackFrame(
self, prev_lab, prev_rp, curr_lab, curr_rp, curr_IDs,
- assign_unique_new_IDs=True, IDs=None, unique_ID=None
+ assign_unique_new_IDs=True, specific_IDs=None, unique_ID=None,
+ dont_return_tracked_lab=False, return_assignments=False,
):
if self.trackWithAcdcAction.isChecked():
tracked_result = CellACDC_tracker.track_frame(
@@ -28651,8 +29121,10 @@ def trackFrame(
setBrushID_func=self.setBrushID,
posData=self.data[self.pos_i],
assign_unique_new_IDs=assign_unique_new_IDs,
- IDs=IDs,
- unique_ID=unique_ID
+ specific_IDs=specific_IDs,
+ unique_ID=unique_ID,
+ return_assignments=return_assignments,
+ dont_return_tracked_lab=dont_return_tracked_lab
)
elif self.trackWithYeazAction.isChecked():
tracked_result = self.tracking_yeaz.correspondence(
@@ -28661,17 +29133,43 @@ def trackFrame(
)
else:
tracked_result = self.trackFrameCustomTracker(
- prev_lab, curr_lab, IDs=IDs, unique_ID=unique_ID
+ prev_lab, curr_lab, specific_IDs=specific_IDs, unique_ID=unique_ID,
+ dont_return_tracked_lab=dont_return_tracked_lab, return_assignments=return_assignments
)
# Check if tracker also returns additional info
+ assignments = None
if isinstance(tracked_result, tuple):
- tracked_lab, tracked_lost_IDs = tracked_result
- self.handleAdditionalInfoRealTimeTracker(prev_rp, tracked_lost_IDs)
+ tracked_lab, add_info = tracked_result
+ assignments = self.handleAdditionalInfoRealTimeTracker(
+ prev_rp, add_info)
+ elif isinstance(tracked_result, dict) and dont_return_tracked_lab:
+ add_info = tracked_result
+ if 'assignments' in add_info: # if still entire add_info is returned
+ assignments = self.handleAdditionalInfoRealTimeTracker(
+ prev_rp, add_info)
+ else:
+ assignments = add_info # its just assignements
else:
tracked_lab = tracked_result
- return tracked_lab
+ if not return_assignments and not dont_return_tracked_lab:
+ return tracked_lab
+
+ # get assignments
+ if assignments is None:
+ assignments = dict()
+ for obj in curr_rp:
+ try:
+ old_lab = obj.label
+ new_lab = tracked_lab[obj.slice][obj.image][0]
+ assignments[old_lab] = new_lab
+ except:
+ import pdb; pdb.set_trace()
+
+ if dont_return_tracked_lab:
+ return assignments
+ return tracked_lab, assignments
def clearAssignedObjsSecondStep(self):
posData = self.data[self.pos_i]
@@ -28681,37 +29179,32 @@ def trackSubsetIDs(self, subsetIDs: Iterable[int]):
posData = self.data[self.pos_i]
if posData.frame_i == 0:
return
-
- subsetLab = np.zeros_like(posData.lab)
- for subsetID in subsetIDs:
- subsetLab[posData.lab == subsetID] = subsetID
prev_lab = posData.allData_li[posData.frame_i-1]['labels']
prev_rp = posData.allData_li[posData.frame_i-1]['regionprops']
- tracked_lab = self.trackFrame(
+ assignments = self.trackFrame(
prev_lab, prev_rp, posData.lab, posData.rp, posData.IDs,
- assign_unique_new_IDs=True
- )
- doUpdateRp = False
- for subsetID in subsetIDs:
- subsetIDmask = posData.lab == subsetID
- trackedID = tracked_lab[subsetIDmask][0]
- if trackedID == subsetID:
- continue
-
- is_manually_edited = False
- for y, x, new_ID in posData.editID_info:
- if new_ID == subsetID:
+ assign_unique_new_IDs=True, specific_IDs=subsetIDs,
+ dont_return_tracked_lab=True
+ )
+ # I think assignments already avoids merging
+ assignments_new = dict()
+ for old_ID, new_ID in assignments.items():
+ # get "old" id based on assignments
+ if old_ID == new_ID:
+ continue # nothing to do
+
+ for y, x, editID in posData.editID_info:
+ if editID == old_ID or editID == new_ID:
# Do not track because it was manually edited
- break
+ continue
- posData.lab[subsetIDmask] = tracked_lab[subsetIDmask]
- doUpdateRp = True
-
- if not doUpdateRp:
- return
+
+ obj = posData.rp.get_obj_from_ID(old_ID) # pr is still old, so we need to get the old ID
+ posData.lab[obj.slice][obj.image] = new_ID
+ assignments_new[old_ID] = new_ID # old ID : new tracked ID
- self.update_rp()
+ self.update_rp(assignments=assignments_new)
def doSkipTracking(self, against_next: bool, enforce: bool):
if self.isSnapshot:
@@ -28760,13 +29253,13 @@ def tracking(
storeUndo=False, prev_lab=None, prev_rp=None,
return_lab=False, assign_unique_new_IDs=True,
separateByLabel=True, wl_update=True,
- IDs=None, against_next=False,
+ against_next=False, specific_IDs=None , return_assignments=False
):
posData = self.data[self.pos_i]
-
+ return_tuple = (None, None) if return_assignments and return_lab else None
if self.doSkipTracking(against_next, enforce):
self.setLostNewOldPrevIDs()
- return
+ return return_tuple
"""Tracking starts here"""
staturBarLabelText = self.statusBarLabel.text()
@@ -28800,41 +29293,54 @@ def tracking(
if posData.frame_i < self.get_last_tracked_i():
unique_ID = self.setBrushID(return_val=True)
- tracked_lab = self.trackFrame(
+ tracked_lab, assignments = self.trackFrame(
prev_lab, prev_rp, posData.lab, posData.rp, posData.IDs,
- assign_unique_new_IDs=assign_unique_new_IDs, IDs=IDs,
- unique_ID=unique_ID
+ assign_unique_new_IDs=assign_unique_new_IDs,
+ unique_ID=unique_ID, specific_IDs=specific_IDs,
+ return_assignments=True
)
if DoManualEdit:
# Correct tracking with manually changed IDs
- rp = skimage.measure.regionprops(tracked_lab)
- IDs = [obj.label for obj in rp]
- self.manuallyEditTracking(tracked_lab, IDs)
+ tracked_lab, assignments = self.manuallyEditTracking(tracked_lab, assignments)
if return_lab:
QTimer.singleShot(50, partial(
self.statusBarLabel.setText, staturBarLabelText
))
+ if return_assignments:
+ return tracked_lab, assignments
return tracked_lab
# Update labels, regionprops and determine new and lost IDs
posData.lab = tracked_lab
- self.update_rp(wl_update=wl_update, )
+ self.update_rp(wl_update=wl_update, assignments=assignments)
self.setAllTextAnnotations()
QTimer.singleShot(50, partial(
self.statusBarLabel.setText, staturBarLabelText
))
+ if return_assignments and return_lab:
+ return tracked_lab, assignments
+ elif return_assignments:
+ return assignments
+ elif return_lab:
+ return tracked_lab
- def handleAdditionalInfoRealTimeTracker(self, prev_rp, *args):
+ def handleAdditionalInfoRealTimeTracker(self, prev_rp, add_info):
+ assignments = None
if self._rtTrackerName == 'CellACDC_normal_division':
- tracked_lost_IDs = args[0]
+ tracked_lost_IDs = add_info['mothers']
self.setTrackedLostCentroids(prev_rp, tracked_lost_IDs)
+ assignments = add_info['assignments']
elif self._rtTrackerName == 'CellACDC_2steps':
- if args[0] is None:
- return
- posData = self.data[self.pos_i]
- posData.acdcTracker2stepsAnnotInfo[posData.frame_i] = args[0]
+ assignments = add_info['assignments']
+ if add_info['to_track_tracked_objs_2nd_step'] is not None:
+ posData = self.data[self.pos_i]
+ posData.acdcTracker2stepsAnnotInfo[posData.frame_i] = add_info['to_track_tracked_objs_2nd_step']
+ elif self._rtTrackerName == 'Cell-ACDC':
+ assignments = add_info['assignments']
+
+ return assignments
def keepOnlyNewIDAssignedObjsSecondStep(self, trackedID):
posData = self.data[self.pos_i]
@@ -28895,7 +29401,11 @@ def annotateAssignedObjsAcdcTrackerSecondStep(self):
new_objs_1st_step, lost_objs_1st_step = annotInfo
for lostObj, newObj in zip(lost_objs_1st_step, new_objs_1st_step):
- allContours = self.getObjContours(lostObj, all_external=True)
+ allContours = self.getObjContours(
+ lostObj,
+ all_external=True,
+ include_internal=self.showAllContoursToggle.isChecked()
+ )
for objContours in allContours:
isObjVisible = self.isObjVisible(newObj.bbox)
if not isObjVisible:
@@ -28930,12 +29440,24 @@ def setTrackedLostCentroids(self, prev_rp, tracked_lost_IDs):
"""
posData = self.data[self.pos_i]
frame_i = posData.frame_i
+ prev_lab = posData.allData_li[frame_i-1]['labels']
for obj in prev_rp:
if obj.label not in tracked_lost_IDs:
continue
-
- int_centroid = tuple([int(val) for val in obj.centroid])
+ if isinstance(prev_rp, regionprops.acdcRegionprops):
+ ID = obj.ID
+ centroid = prev_rp.get_centroid(ID, exact=True)
+ else:
+ centroid = obj.centroid
+ int_centroid = tuple([int(val) for val in centroid])
+ # check if centroid has right ID
+ if prev_lab[int_centroid] != ID:
+ # get closest point with the right ID
+ coords = obj.coords
+ distances = np.sqrt(np.sum((coords - centroid) ** 2, axis=1))
+ closest_idx = np.argmin(distances)
+ int_centroid = tuple([int(val) for val in coords[closest_idx]])
try:
posData.tracked_lost_centroids[frame_i].add(int_centroid)
except KeyError:
@@ -28989,28 +29511,59 @@ def getTrackedLostIDs(self, prev_lab=None, IDs_in_frames=None, frame_i=None):
posData.trackedLostIDs = trackedLostIDs
return trackedLostIDs
-
- def manuallyEditTracking(self, tracked_lab, allIDs):
+
+ def manuallyEditTracking(self, tracked_lab, assignments):
posData = self.data[self.pos_i]
infoToRemove = []
- # Correct tracking with manually changed IDs
- maxID = max(allIDs, default=1)
- for y, x, new_ID in posData.editID_info:
- old_ID = tracked_lab[y, x]
- if old_ID == 0 or old_ID == new_ID:
- infoToRemove.append((y, x, new_ID))
+
+ if not assignments:
+ return tracked_lab, assignments
+
+ # !!! RP is stale so we need to reverse search for the ID
+ reversed_assignments = (
+ {tracked_id: stale_id for stale_id, tracked_id in assignments.items()}
+ if assignments else {}
+ )
+ stale_ids = set(posData.rp.IDs)
+
+ covered_edited_IDs = set()
+ for y, x, edited_ID in posData.editID_info:
+ new_ID = assignments.get(edited_ID, edited_ID) # ID in tracked lab
+ if new_ID in covered_edited_IDs:
+ # This ID has already been edited by sawpping for example
continue
- if new_ID in allIDs:
- tempID = maxID+1
- tracked_lab[tracked_lab == old_ID] = tempID
- tracked_lab[tracked_lab == new_ID] = old_ID
- tracked_lab[tracked_lab == tempID] = new_ID
+
+ if new_ID == 0 or new_ID == edited_ID: # edited ID is not tracked to a different ID
+ infoToRemove.append((y, x, edited_ID))
+ continue
+
+ old_RP_ID = reversed_assignments.get(edited_ID, edited_ID) # ID pre tracking
+ old_obj = posData.rp.get_obj_from_ID(old_RP_ID) # obj pre tracking
+
+ if edited_ID in stale_ids:
+ # a swap has been made by the user between an old ID (old_RP_ID) and a new ID (edited_ID)
+ new_obj = posData.rp.get_obj_from_ID(edited_ID)
+ tracked_lab[old_obj.slice][old_obj.image] = edited_ID
+ tracked_lab[new_obj.slice][new_obj.image] = old_RP_ID
+ # update assignemnets
+ assignments[old_RP_ID] = edited_ID
+ assignments[edited_ID] = old_RP_ID
+ # add the two swapped IDs
+
+ covered_edited_IDs.add(edited_ID)
+ covered_edited_IDs.add(old_RP_ID)
+
else:
- tracked_lab[tracked_lab == old_ID] = new_ID
- if new_ID > maxID:
- maxID = new_ID
+ tracked_lab[old_obj.slice][old_obj.image] = edited_ID
+
+ assignments[old_RP_ID] = edited_ID
+
+ covered_edited_IDs.add(edited_ID)
+
for info in infoToRemove:
posData.editID_info.remove(info)
+
+ return tracked_lab, assignments
def warnReinitLastSegmFrame(self):
current_frame_n = self.navigateScrollBar.value()
@@ -30243,8 +30796,12 @@ def initRealTimeTracker(self, force=False):
rtTracker = aliases[rtTracker]
if rtTracker == 'Cell-ACDC':
+ self._rtTrackerName = 'Cell-ACDC'
+ self.realTimeTracker_kwargs = None # This is hard coded
return
if rtTracker == 'YeaZ':
+ self._rtTrackerName = 'YeaZ'
+ self.realTimeTracker_kwargs = None # This is hard coded
return
if self.isRealTimeTrackerInitialized and not force:
@@ -30262,6 +30819,8 @@ def initRealTimeTracker(self, force=False):
self.realTimeTracker = realTimeTracker
self.track_frame_params = track_frame_params
+ self.realTimeTracker_kwargs = inspect.signature(
+ self.realTimeTracker.track_frame).parameters
self.logger.info(f'{rtTracker} tracker successfully initialized.')
if 'image_channel_name' in self.track_frame_params:
# Remove the channel name since it was already loaded in init_tracker
@@ -30755,6 +31314,44 @@ def annotOptionClickedRight(
self.setDrawAnnotComboboxTextRight(saveSettings=saveSettings)
+ def checkHandleTooManyNewItems(self):
+ posData = self.data[self.pos_i]
+ num_objects = len(posData.rp)
+ if num_objects < 1500:
+ return True
+
+ out = _warnings.warnTooManyNewItems(self, num_objects, self)
+ cancel, switchToLowRes, deactivateAnnot = out
+ if cancel:
+ return False
+
+ if switchToLowRes:
+ self.highLowResAction.setChecked(False)
+ self.changeTextResolution()
+ return True
+
+ if deactivateAnnot:
+ self.annotCcaInfoCheckbox.blockSignals(True)
+ self.annotIDsCheckbox.blockSignals(True)
+ self.annotCcaInfoCheckbox.setChecked(False)
+ self.annotIDsCheckbox.setChecked(False)
+ self.annotCcaInfoCheckbox.blockSignals(False)
+ self.annotIDsCheckbox.blockSignals(False)
+
+ self.annotCcaInfoCheckboxRight.blockSignals(True)
+ self.annotIDsCheckboxRight.blockSignals(True)
+ self.annotCcaInfoCheckboxRight.setChecked(False)
+ self.annotIDsCheckboxRight.setChecked(False)
+ self.annotCcaInfoCheckboxRight.blockSignals(False)
+ self.annotIDsCheckboxRight.blockSignals(False)
+
+ self.textAnnot[0].setCcaAnnot(False)
+ self.textAnnot[0].setLabelAnnot(False)
+ self.textAnnot[1].setCcaAnnot(False)
+ self.textAnnot[1].setLabelAnnot(False)
+ return True
+
+
def setAnnotOptionsCcaMode(self):
self.prevAnnotOptions = self.storeCurrentAnnotOptions_ax1(
return_value=True
@@ -32010,7 +32607,9 @@ def getZoomIDs(self, viewRange=None):
)
zoomLab = skimage.segmentation.clear_border(lab[zoomSlice])
- zoomRp = skimage.measure.regionprops(zoomLab)
+ zoomRp = regionprops.acdcRegionprops(
+ zoomLab, precache_centroids=False
+ )
zoomIDs = [obj.label for obj in zoomRp]
return zoomIDs
diff --git a/cellacdc/load.py b/cellacdc/load.py
index 62a336464..5c6f634a8 100755
--- a/cellacdc/load.py
+++ b/cellacdc/load.py
@@ -40,7 +40,7 @@
from . import io
from . import core
from . import IMAGE_EXTENSIONS, VIDEO_EXTENSIONS
-
+from . import fonts
from . import GUI_INSTALLED
if GUI_INSTALLED:
@@ -1306,6 +1306,7 @@ def __init__(
self.loadLastEntriesMetadata()
self.attempFixBasenameBug()
self.non_aligned_ext = '.tif'
+ self.segmMetadata = None
if filename_ext.endswith('aligned.npz'):
for file in myutils.listdir(self.images_path):
if file.endswith(f'{user_ch_name}.h5'):
@@ -1646,7 +1647,13 @@ def countObjectsInSegmTimelapse(self, categories: set[str] | list[str]):
for frame_i in range(len(self.segm_data)):
lab = self.allData_li[frame_i]['labels']
if lab is not None:
- IDsFrame = self.allData_li[frame_i]['IDs']
+ if hasattr(self.allData_li[frame_i]['regionprops'], 'IDs'):
+ IDsFrame = self.allData_li[frame_i]['regionprops'].IDs
+ else:
+ IDsFrame = [
+ obj.label
+ for obj in self.allData_li[frame_i]['regionprops']
+ ]
if uniqueIDsVisited is not None:
uniqueIDsVisited.update(IDsFrame)
@@ -1866,6 +1873,7 @@ def loadOtherFiles(
new_endname='',
labelBoolSegm=None,
load_whitelistIDs=False,
+ load_segm_info_ini=False
):
self.segmFound = False if load_segm_data else None
self.acdc_df_found = False if load_acdc_df else None
@@ -2069,6 +2077,9 @@ def loadOtherFiles(
if load_whitelistIDs:
self.loadWhitelist()
+
+ if load_segm_info_ini:
+ self.readSegmMetadataIni()
def checkAndFixZsliceSegmInfo(self):
if not hasattr(self, 'segmInfo_df'):
@@ -2322,14 +2333,14 @@ def fromTrackerToAcdcDf(
rp = skimage.measure.regionprops(lab)
for obj in rp:
centroid = obj.centroid
- yc, xc = obj.centroid[-2:]
+ yc, xc = centroid[-2:]
acdc_df.at[(frame_i, obj.label), 'x_centroid'] = int(xc)
acdc_df.at[(frame_i, obj.label), 'y_centroid'] = int(yc)
if len(centroid) == 3:
if 'z_centroid' not in acdc_df.columns:
acdc_df['z_centroid'] = 0
- zc = obj.centroid[0]
+ zc = centroid[0]
acdc_df.at[(frame_i, obj.label), 'z_centroid'] = int(zc)
if not save:
@@ -3059,6 +3070,7 @@ def buildPaths(self):
self.raw_postproc_segm_path = f'{base_path}segm_raw_postproc'
self.post_proc_mot_metrics = f'{base_path}post_proc_mot_metrics'
self.segm_hyperparams_ini_path = f'{base_path}segm_hyperparams.ini'
+ self.segm_metadata_ini_path = f'{base_path}segm_metadata_data.ini'
self.custom_annot_json_path = f'{base_path}custom_annot_params.json'
self.custom_combine_metrics_path = (
f'{base_path}custom_combine_metrics.ini'
@@ -3082,6 +3094,7 @@ def get_tracker_export_path(self, trackerName, ext):
def setBlankSegmData(self, SizeT, SizeZ, SizeY, SizeX):
if not hasattr(self, 'img_data'):
self.segm_data = None
+ self.single_timepoint_size = None
return
Y, X = self.img_data.shape[-2:]
@@ -3093,7 +3106,16 @@ def setBlankSegmData(self, SizeT, SizeZ, SizeY, SizeX):
elif SizeT > 1:
self.segm_data = np.zeros((SizeT, Y, X), int)
else:
- self.segm_data = np.zeros((Y, X), int)
+ self.segm_data = np.zeros((Y, X), int)
+
+ def getSingleTimepointSegmSize(self):
+ if hasattr(self, 'single_timepoint_size') and self.single_timepoint_size is not None:
+ return self.single_timepoint_size
+ if self.SizeT > 1:
+ self.single_timepoint_size = np.prod(self.segm_data.shape[1:])
+ else: # not sure if time axis is present but would be 1 anyways
+ self.single_timepoint_size = np.prod(self.segm_data.shape)
+ return self.single_timepoint_size
def loadAllImgPaths(self):
tif_paths = []
@@ -3185,7 +3207,7 @@ def askInputMetadata(
self.SizeT, self.SizeZ, self.TimeIncrement,
self.PhysicalSizeZ, self.PhysicalSizeY, self.PhysicalSizeX,
ask_SizeT, ask_TimeIncrement, ask_PhysicalSizes,
- parent=self.parent, font=apps.font, imgDataShape=self.img_data_shape,
+ parent=self.parent, font=fonts.font, imgDataShape=self.img_data_shape,
posData=self, singlePos=singlePos, askSegm3D=askSegm3D,
additionalValues=self._additionalMetadataValues,
forceEnableAskSegm3D=forceEnableAskSegm3D,
@@ -3421,7 +3443,15 @@ def loadWhitelist(self):
self.whitelist = whitelist.Whitelist(
total_frames=self.SizeT,
)
- whitelist_path = self.segm_npz_path.replace('.npz', '_whitelistIDs.json')
+ whitelist_path_legacy = self.segm_npz_path.replace(
+ '.npz', '_whitelistIDs.json')
+ segm_filename = os.path.basename(self.segm_npz_path).replace('.npz', '')
+ segm_add_data_folder = os.path.join(self.images_path, segm_filename)
+ os.makedirs(segm_add_data_folder, exist_ok=True)
+ whitelist_path = os.path.join(segm_add_data_folder, 'whitelistIDs.json')
+ if os.path.exists(whitelist_path_legacy):
+ # move to new path
+ shutil.move(whitelist_path_legacy, whitelist_path)
new_centroids_path = self.segm_npz_path.replace('.npz', '_new_centroids.json')
success = self.whitelist.load(
whitelist_path, new_centroids_path, self.segm_data, self.allData_li,
@@ -3432,7 +3462,96 @@ def loadWhitelist(self):
if not success:
self.whitelist = None
-
+ def readSegmMetadataIni(self):
+ if not os.path.exists(self.segm_metadata_ini_path):
+ return None
+
+ cp = config.ConfigParser()
+ cp.read(self.segm_metadata_ini_path)
+ # one entry for each segmentation file
+ self.segmMetadata = {}
+ for segm_file in cp.sections():
+ sizeX = cp.getint(segm_file, 'sizeX', fallback=None)
+ sizeY = cp.getint(segm_file, 'sizeY', fallback=None)
+ sizeT = cp.getint(segm_file, 'SizeT', fallback=None)
+ sizeZ = cp.getint(segm_file, 'SizeZ', fallback=None)
+ is_3D = sizeZ > 1 if sizeZ is not None else False
+ last_modified_date = cp.get(segm_file, 'last_modified_date', fallback=None)
+ acdc_df_segm = cp.get(segm_file, 'acdc_df_segm', fallback=None)
+ acdc_df_save_date = cp.get(segm_file, 'acdc_df_save_date', fallback=None)
+ self.segmMetadata[segm_file] = {
+ 'SizeT': sizeT,
+ 'SizeZ': sizeZ,
+ 'is_3D': is_3D,
+ 'last_modified_date': last_modified_date,
+ 'acdc_df_segm': acdc_df_segm,
+ 'acdc_df_save_date': acdc_df_save_date,
+ 'sizeX': sizeX,
+ 'sizeY': sizeY,
+ }
+
+ def saveSegmMetadataIni(self):
+ # need to be called in more locations, will be full yimplemented in workflow gui
+ cp = config.ConfigParser()
+ for segm_file, metadata in self.segmMetadata.items():
+ cp[segm_file] = {}
+ cp[segm_file]['SizeT'] = str(metadata.get('SizeT', ''))
+ cp[segm_file]['SizeZ'] = str(metadata.get('SizeZ', ''))
+ cp[segm_file]['last_modified_date'] = str(metadata.get('last_modified_date', ''))
+ cp[segm_file]['acdc_df_segm'] = str(metadata.get('acdc_df_segm', ''))
+ cp[segm_file]['sizeX'] = str(metadata.get('sizeX', ''))
+ cp[segm_file]['sizeY'] = str(metadata.get('sizeY', ''))
+ cp[segm_file]['acdc_df_save_date'] = str(
+ metadata.get('acdc_df_save_date', '')
+ )
+
+ with open(self.segm_metadata_ini_path, 'w') as configfile:
+ cp.write(configfile)
+
+ def updateSegmMetadata(self, segm_file=None, SizeT=None, SizeZ=None,
+ acdc_df_segm=None, last_modified_date=None,
+ sizeY=None, sizeX=None, all=False, acdc_df_save_date=None):
+ if segm_file is None:
+ segm_file = os.path.basename(self.segm_npz_path)
+
+ if self.segmMetadata is None:
+ self.segmMetadata = {}
+ segm_metadata = self.segmMetadata.get(segm_file, {})
+ if SizeT is not None or all:
+ if SizeT is True or SizeT is None:
+ SizeT = self.SizeT
+ segm_metadata['SizeT'] = SizeT
+ if SizeZ is not None or all:
+ if SizeZ is True or SizeZ is None:
+ SizeZ = self.SizeZ if self.isSegm3D else 1
+ segm_metadata['SizeZ'] = SizeZ
+ segm_metadata['is_3D'] = SizeZ > 1
+ if acdc_df_segm is not None or all:
+ if acdc_df_segm is True or acdc_df_segm is None:
+ acdc_df_segm = os.path.basename(self.acdc_output_csv_path) # for future if we allow multpiple outputs
+ # clear other segm metadata entries with acdc_df info to avoid confusion
+ for info in self.segmMetadata.values():
+ if info.get('acdc_df_segm', '') == acdc_df_segm:
+ info['acdc_df_segm'] = None
+ segm_metadata['acdc_df_segm'] = acdc_df_segm
+ if last_modified_date is not None or all:
+ if last_modified_date is True or last_modified_date is None: # explicitly in this cane set curr datetime
+ last_modified_date = datetime.now()
+ segm_metadata['last_modified_date'] = last_modified_date
+ if sizeY is not None or all:
+ if sizeY is True or sizeY is None:
+ sizeY = self.SizeY
+ segm_metadata['sizeY'] = sizeY
+ if sizeX is not None or all:
+ if sizeX is True or sizeX is None:
+ sizeX = self.SizeX
+ segm_metadata['sizeX'] = sizeX
+ if acdc_df_save_date is not None or all:
+ if acdc_df_save_date is True or acdc_df_save_date is None:
+ acdc_df_save_date = datetime.now()
+ segm_metadata['acdc_df_save_date'] = acdc_df_save_date
+ self.segmMetadata[segm_file] = segm_metadata
+
class select_exp_folder:
def __init__(self):
self.exp_path = None
diff --git a/cellacdc/myutils.py b/cellacdc/myutils.py
index f2695bcd3..51ae24e98 100644
--- a/cellacdc/myutils.py
+++ b/cellacdc/myutils.py
@@ -55,6 +55,7 @@
from . import urls
from . import qrc_resources_path
from . import settings_folderpath
+from . import regionprops
from .models._cellpose_base import min_target_versions_cp
if GUI_INSTALLED:
@@ -344,6 +345,7 @@ def __init__(
level=logging.DEBUG
):
super().__init__(f'{name}-{module}', level=level)
+ self.propagate = False # prevent UnicodeEncodeError via root StreamHandler
self._stdout = sys.stdout
self._stderr = StdErr(logger=self)
sys.stderr = self._stderr
@@ -367,8 +369,13 @@ def write(self, text, log_to_file=True, write_to_stdout=True):
log_to_file : bool, optional
If True, call `info` method with `text`. Default is True
"""
- if write_to_stdout:
- self._stdout.write(text)
+ if write_to_stdout:
+ try:
+ self._stdout.write(text)
+ except UnicodeEncodeError:
+ self._stdout.write(text.encode(
+ self._stdout.encoding, errors='replace'
+ ).decode(self._stdout.encoding))
if not log_to_file:
return
@@ -378,8 +385,11 @@ def write(self, text, log_to_file=True, write_to_stdout=True):
if not text:
return
-
- self.debug(text)
+
+ try:
+ self.debug(text)
+ except UnicodeEncodeError:
+ self.debug(text.encode('ascii', errors='replace').decode('ascii'))
def close(self):
for handler in self.handlers:
@@ -614,7 +624,7 @@ def setupLogger(module='base', logs_path=None, caller='Cell-ACDC'):
log_filename = f'{date_time}_{module}_{id}_stdout.log'
log_path = os.path.join(logs_path, log_filename)
- output_file_handler = logging.FileHandler(log_path, mode='w')
+ output_file_handler = logging.FileHandler(log_path, mode='w', encoding='utf-8')
# Format your logs (optional)
formatter = logging.Formatter(
@@ -1033,22 +1043,6 @@ def showInExplorer(path):
else:
os.startfile(path)
-def exec_time(func):
- @wraps(func)
- def inner_function(self, *args, **kwargs):
- t0 = time.perf_counter()
- if func.__code__.co_argcount==1 and func.__defaults__ is None:
- result = func(self)
- elif func.__code__.co_argcount>1 and func.__defaults__ is None:
- result = func(self, *args)
- else:
- result = func(self, *args, **kwargs)
- t1 = time.perf_counter()
- s = f'{func.__name__} execution time = {(t1-t0)*1000:.3f} ms'
- printl(s, is_decorator=True)
- return result
- return inner_function
-
def setRetainSizePolicy(widget, retain=True):
sp = widget.sizePolicy()
sp.setRetainSizeWhenHidden(retain)
@@ -1111,22 +1105,61 @@ def get_chname_from_basename(filename, basename, remove_ext=True):
chName = chName[:aligned_idx]
return chName
+def _edge_ids_2d(lab):
+ border_labels = np.r_[
+ lab[0, :],
+ lab[-1, :],
+ lab[:, 0],
+ lab[:, -1],
+ ]
+ return np.unique(border_labels[border_labels != 0])
+
+def _edge_ids_3d(lab):
+ face_labels = np.r_[
+ lab[ 0, :, :].ravel(), # z min
+ lab[-1, :, :].ravel(), # z max
+ lab[:, 0, :].ravel(), # y min
+ lab[:, -1, :].ravel(), # y max
+ lab[:, :, 0].ravel(), # x min
+ lab[:, :, -1].ravel(), # x max
+ ]
+ ids = np.unique(face_labels)
+ return ids[ids != 0]
+
+def get_edge_ids(lab):
+ if lab.ndim == 2:
+ return _edge_ids_2d(lab)
+ elif lab.ndim == 3:
+ return _edge_ids_3d(lab)
+ else:
+ raise ValueError('Label array must be either 2D or 3D.')
+
+def clear_border(lab, return_edge_ids=False):
+ # probably faster than skimage since it avoids relabeling...
+ # assumes continous unique IDs, which we have. Modifies inplace!
+ edge_ids = get_edge_ids(lab)
+ lab[np.isin(lab, edge_ids)] = 0
+ if return_edge_ids:
+ return edge_ids
+
def getBaseAcdcDf(rp):
zeros_list = [0]*len(rp)
nones_list = [None]*len(rp)
minus1_list = [-1]*len(rp)
- IDs = []
- xx_centroid = []
- yy_centroid = []
- zz_centroid = []
- for obj in rp:
- xc, yc = obj.centroid[-2:]
- IDs.append(obj.label)
- xx_centroid.append(xc)
- yy_centroid.append(yc)
- if len(obj.centroid) == 3:
- zc = obj.centroid[0]
- zz_centroid.append(zc)
+ IDs = [0]*len(rp)
+ xx_centroid = [0]*len(rp)
+ yy_centroid = [0]*len(rp)
+ zz_centroid = [0]*len(rp)
+
+ for i, obj in enumerate(rp):
+ centroid = obj.centroid
+ xc, yc = centroid[-2:]
+ IDs[i] = obj.label
+ xx_centroid[i] = xc
+ yy_centroid[i] = yc
+ if len(centroid) == 3:
+ zc = centroid[0]
+ zz_centroid[i] = zc
df = pd.DataFrame(
{
@@ -1138,7 +1171,7 @@ def getBaseAcdcDf(rp):
'was_manually_edited': minus1_list
}
).set_index('Cell_ID')
- if zz_centroid:
+ if len(centroid) == 3:
df['z_centroid'] = zz_centroid
return df
@@ -1705,7 +1738,7 @@ def download_java():
def get_model_path(model_name, create_temp_dir=True):
if model_name == 'Automatic thresholding':
- model_name == 'thresholding'
+ model_name = 'thresholding'
model_info_path = os.path.join(cellacdc_path, 'models', model_name, 'model')
@@ -2337,7 +2370,7 @@ def lab2d_to_rois(ImagejRoi, lab2D, ndigits, t=None, z=None):
rp = skimage.measure.regionprops(lab2D)
rois = []
for obj in rp:
- cont = core.get_obj_contours(obj)
+ cont = core.get_obj_contours(obj=obj)
yc, xc = obj.centroid
x_str = str((int(xc))).zfill(ndigits)
y_str = str((int(yc))).zfill(ndigits)
@@ -4113,13 +4146,31 @@ def init_tracker(
return tracker, track_params
def import_segment_module(model_name):
+ original_model_name = model_name
+ if model_name == 'Automatic thresholding':
+ model_name = 'thresholding'
+
try:
acdcSegment = import_module(f'cellacdc.models.{model_name}.acdcSegment')
except ModuleNotFoundError as e:
+ # Do not mask missing dependencies imported by the module itself.
+ expected_missing_module = f'cellacdc.models.{model_name}'
+ if e.name != expected_missing_module:
+ raise
+
# Check if custom model
cp = config.ConfigParser()
cp.read(models_list_file_path)
- model_path = cp[model_name]['path']
+ model_key = None
+ for key in (original_model_name, model_name):
+ if key in cp:
+ model_key = key
+ break
+
+ if model_key is None:
+ raise
+
+ model_path = cp[model_key]['path']
spec = importlib.util.spec_from_file_location('acdcSegment', model_path)
acdcSegment = importlib.util.module_from_spec(spec)
sys.modules['acdcSegment'] = acdcSegment
@@ -5121,7 +5172,6 @@ def get_empty_stored_data_dict():
'delROIs_info': {
'rois': [], 'delMasks': [], 'delIDsROI': [], 'state': []
},
- 'IDs': [],
'manually_edited_lab': {'lab': {}, 'zoom_slice': None}
}
@@ -5503,7 +5553,7 @@ def find_distances_ID(rps, point=None, ID=None):
if ID is not None and point is None:
try:
- point = [rp.centroid for rp in rps if rp.label == ID][0]
+ point = rps.get_centroid(ID)
except IndexError:
raise ValueError(f'ID {ID} not found in regionprops (list of cells).')
@@ -5515,7 +5565,7 @@ def find_distances_ID(rps, point=None, ID=None):
point = point[::-1] # rp are in (y, x) format (or (z, y, x) for 3D data) so I need to reverse order
point = np.array([point])
- centroids = np.array([rp.centroid for rp in rps])
+ centroids = np.array([rps.get_centroid(ID) for ID in rps.IDs])
diff = point[:, np.newaxis] - centroids
dist_matrix = np.linalg.norm(diff, axis=2)
return dist_matrix
@@ -5552,7 +5602,7 @@ def sort_IDs_dist(rps, point=None, ID=None):
"""
if ID is not None and point is None:
try:
- point = [rp.centroid for rp in rps if rp.label == ID][0]
+ point = rps.get_centroid(ID)
except IndexError:
raise ValueError(f'ID {ID} not found in regionprops (list of cells).')
@@ -5563,7 +5613,7 @@ def sort_IDs_dist(rps, point=None, ID=None):
raise ValueError('Only one of ID or point must be provided.')
- IDs = [rp.label for rp in rps]
+ IDs = rps.IDs
if len(IDs) == 0:
return []
elif len(IDs) == 1:
diff --git a/cellacdc/plot.py b/cellacdc/plot.py
index 3d33b3c23..34ffe812d 100644
--- a/cellacdc/plot.py
+++ b/cellacdc/plot.py
@@ -15,6 +15,7 @@
from mpl_toolkits.axes_grid1 import make_axes_locatable
import seaborn as sns
+from . import debugutils
from tqdm import tqdm
from . import GUI_INSTALLED
@@ -25,7 +26,7 @@
from . import printl
from . import _core, error_below, error_close
-from . import _run, core, myutils
+from . import _run, core, myutils, regionprops as acdc_regionprops
def matplotlib_cmap_to_lut(
cmap: Union[Iterable, matplotlib.colors.Colormap, str],
@@ -768,19 +769,22 @@ def plt_contours(
clear_borders=True, obj_contours_kwargs=None
):
if rp is None:
- rp = skimage.measure.regionprops(lab)
+ rp = acdc_regionprops.acdcRegionprops(lab, precache_centroids=False)
if plot_kwargs is None:
plot_kwargs = {}
if obj_contours_kwargs is None:
obj_contours_kwargs = {}
+ elif 'include_internal' in obj_contours_kwargs:
+ obj_contours_kwargs = obj_contours_kwargs.copy()
+ obj_contours_kwargs['all'] = obj_contours_kwargs.pop('include_internal')
for obj in rp:
if only_IDs is not None and obj.label not in only_IDs:
continue
- contours = core.get_obj_contours(obj, **obj_contours_kwargs)
+ contours = core.get_obj_contours(obj=obj, **obj_contours_kwargs)
if not isinstance(contours, list):
contours = [contours]
diff --git a/cellacdc/precompiled/precompiled_functions.cp310-win_amd64.pyd b/cellacdc/precompiled/precompiled_functions.cp310-win_amd64.pyd
new file mode 100644
index 000000000..6728157cd
Binary files /dev/null and b/cellacdc/precompiled/precompiled_functions.cp310-win_amd64.pyd differ
diff --git a/cellacdc/precompiled/precompiled_functions.cp311-win_amd64.pyd b/cellacdc/precompiled/precompiled_functions.cp311-win_amd64.pyd
new file mode 100644
index 000000000..951ad121d
Binary files /dev/null and b/cellacdc/precompiled/precompiled_functions.cp311-win_amd64.pyd differ
diff --git a/cellacdc/precompiled/precompiled_functions.cp312-win_amd64.pyd b/cellacdc/precompiled/precompiled_functions.cp312-win_amd64.pyd
new file mode 100644
index 000000000..9d5d0547f
Binary files /dev/null and b/cellacdc/precompiled/precompiled_functions.cp312-win_amd64.pyd differ
diff --git a/cellacdc/precompiled/precompiled_functions.cp313-win_amd64.pyd b/cellacdc/precompiled/precompiled_functions.cp313-win_amd64.pyd
new file mode 100644
index 000000000..c0ad5b52b
Binary files /dev/null and b/cellacdc/precompiled/precompiled_functions.cp313-win_amd64.pyd differ
diff --git a/cellacdc/precompiled/precompiled_functions.cpython-310-darwin.so b/cellacdc/precompiled/precompiled_functions.cpython-310-darwin.so
new file mode 100644
index 000000000..0b70d8007
Binary files /dev/null and b/cellacdc/precompiled/precompiled_functions.cpython-310-darwin.so differ
diff --git a/cellacdc/precompiled/precompiled_functions.cpython-310-x86_64-linux-gnu.so b/cellacdc/precompiled/precompiled_functions.cpython-310-x86_64-linux-gnu.so
new file mode 100644
index 000000000..e84696138
Binary files /dev/null and b/cellacdc/precompiled/precompiled_functions.cpython-310-x86_64-linux-gnu.so differ
diff --git a/cellacdc/precompiled/precompiled_functions.cpython-311-darwin.so b/cellacdc/precompiled/precompiled_functions.cpython-311-darwin.so
new file mode 100644
index 000000000..9e46a0cb0
Binary files /dev/null and b/cellacdc/precompiled/precompiled_functions.cpython-311-darwin.so differ
diff --git a/cellacdc/precompiled/precompiled_functions.cpython-311-x86_64-linux-gnu.so b/cellacdc/precompiled/precompiled_functions.cpython-311-x86_64-linux-gnu.so
new file mode 100644
index 000000000..96b91d778
Binary files /dev/null and b/cellacdc/precompiled/precompiled_functions.cpython-311-x86_64-linux-gnu.so differ
diff --git a/cellacdc/precompiled/precompiled_functions.cpython-312-darwin.so b/cellacdc/precompiled/precompiled_functions.cpython-312-darwin.so
new file mode 100644
index 000000000..2a82f1217
Binary files /dev/null and b/cellacdc/precompiled/precompiled_functions.cpython-312-darwin.so differ
diff --git a/cellacdc/precompiled/precompiled_functions.cpython-312-x86_64-linux-gnu.so b/cellacdc/precompiled/precompiled_functions.cpython-312-x86_64-linux-gnu.so
new file mode 100644
index 000000000..14c6dd994
Binary files /dev/null and b/cellacdc/precompiled/precompiled_functions.cpython-312-x86_64-linux-gnu.so differ
diff --git a/cellacdc/precompiled/precompiled_functions.cpython-313-darwin.so b/cellacdc/precompiled/precompiled_functions.cpython-313-darwin.so
new file mode 100644
index 000000000..d715591ef
Binary files /dev/null and b/cellacdc/precompiled/precompiled_functions.cpython-313-darwin.so differ
diff --git a/cellacdc/precompiled/precompiled_functions.cpython-313-x86_64-linux-gnu.so b/cellacdc/precompiled/precompiled_functions.cpython-313-x86_64-linux-gnu.so
new file mode 100644
index 000000000..f19d4d7f0
Binary files /dev/null and b/cellacdc/precompiled/precompiled_functions.cpython-313-x86_64-linux-gnu.so differ
diff --git a/cellacdc/precompiled_functions.pyx b/cellacdc/precompiled_functions.pyx
new file mode 100644
index 000000000..f95298411
--- /dev/null
+++ b/cellacdc/precompiled_functions.pyx
@@ -0,0 +1,493 @@
+# precompiled_functions.pyx
+# cython: boundscheck=False, wraparound=False, cdivision=True
+# rand change to trigger gh actions: 2
+import numpy as np
+cimport numpy as np
+from libc.limits cimport UINT_MAX
+
+def find_all_objects_2D(np.uint32_t[:, :] label_img):
+ cdef Py_ssize_t n_rows = label_img.shape[0]
+ cdef Py_ssize_t n_cols = label_img.shape[1]
+ cdef Py_ssize_t i, j
+ cdef unsigned int label, max_label = 0
+ cdef unsigned int capacity = 300, new_cap
+
+ cdef np.ndarray[np.uint32_t, ndim=1] _rs = np.full(capacity, UINT_MAX, dtype=np.uint32)
+ cdef np.ndarray[np.uint32_t, ndim=1] _re = np.zeros(capacity, dtype=np.uint32)
+ cdef np.ndarray[np.uint32_t, ndim=1] _cs = np.full(capacity, UINT_MAX, dtype=np.uint32)
+ cdef np.ndarray[np.uint32_t, ndim=1] _ce = np.zeros(capacity, dtype=np.uint32)
+
+ cdef unsigned int[:] rs = _rs, re = _re, cs = _cs, ce = _ce
+
+ # Single pass: compute bounding boxes, growing arrays in 300-label steps if needed
+ for i in range(n_rows):
+ for j in range(n_cols):
+ label = label_img[i, j]
+ if label == 0:
+ continue
+ if label >= capacity:
+ new_cap = ((label // 300) + 1) * 300
+ _rs = np.concatenate((_rs, np.full(new_cap - capacity, UINT_MAX, dtype=np.uint32)))
+ _re = np.concatenate((_re, np.zeros(new_cap - capacity, dtype=np.uint32)))
+ _cs = np.concatenate((_cs, np.full(new_cap - capacity, UINT_MAX, dtype=np.uint32)))
+ _ce = np.concatenate((_ce, np.zeros(new_cap - capacity, dtype=np.uint32)))
+ rs = _rs; re = _re; cs = _cs; ce = _ce
+ capacity = new_cap
+ if label > max_label:
+ max_label = label
+ if i < rs[label]: rs[label] = i
+ if i + 1 > re[label]: re[label] = (i + 1)
+ if j < cs[label]: cs[label] = j
+ if j + 1 > ce[label]: ce[label] = (j + 1)
+
+ if max_label == 0:
+ return [], []
+
+ # Collect present labels into compact numpy arrays (avoids per-label tuple allocation)
+ cdef unsigned int n_labels = 0
+ for lbl in range(1, max_label + 1):
+ if re[lbl] != 0:
+ n_labels += 1
+
+ cdef np.ndarray[np.uint32_t, ndim=1] out_labels = np.empty(n_labels, dtype=np.uint32)
+ cdef np.ndarray[np.uint32_t, ndim=2] out_bboxes = np.empty((n_labels, 4), dtype=np.uint32)
+ cdef unsigned int idx = 0
+ for lbl in range(1, max_label + 1):
+ if re[lbl] != 0:
+ out_labels[idx] = lbl
+ out_bboxes[idx, 0] = rs[lbl]
+ out_bboxes[idx, 1] = re[lbl]
+ out_bboxes[idx, 2] = cs[lbl]
+ out_bboxes[idx, 3] = ce[lbl]
+ idx += 1
+ return out_labels, out_bboxes
+
+def find_all_objects_3D(np.uint32_t[:, :, :] label_img):
+ cdef Py_ssize_t n_z = label_img.shape[0]
+ cdef Py_ssize_t n_rows = label_img.shape[1]
+ cdef Py_ssize_t n_cols = label_img.shape[2]
+ cdef Py_ssize_t i, j, k
+ cdef unsigned int label, max_label = 0
+ cdef unsigned int capacity = 300, new_cap
+
+ cdef np.ndarray[np.uint32_t, ndim=1] _zs = np.full(capacity, UINT_MAX, dtype=np.uint32)
+ cdef np.ndarray[np.uint32_t, ndim=1] _ze = np.zeros(capacity, dtype=np.uint32)
+ cdef np.ndarray[np.uint32_t, ndim=1] _rs = np.full(capacity, UINT_MAX, dtype=np.uint32)
+ cdef np.ndarray[np.uint32_t, ndim=1] _re = np.zeros(capacity, dtype=np.uint32)
+ cdef np.ndarray[np.uint32_t, ndim=1] _cs = np.full(capacity, UINT_MAX, dtype=np.uint32)
+ cdef np.ndarray[np.uint32_t, ndim=1] _ce = np.zeros(capacity, dtype=np.uint32)
+
+ cdef unsigned int[:] zs = _zs, ze = _ze, rs = _rs, re = _re, cs = _cs, ce = _ce
+
+ # Single pass: compute bounding boxes, growing arrays in 300-label steps if needed
+ for i in range(n_z):
+ for j in range(n_rows):
+ for k in range(n_cols):
+ label = label_img[i, j, k]
+ if label == 0:
+ continue
+ if label >= capacity:
+ new_cap = ((label // 300) + 1) * 300
+ _zs = np.concatenate((_zs, np.full(new_cap - capacity, UINT_MAX, dtype=np.uint32)))
+ _ze = np.concatenate((_ze, np.zeros(new_cap - capacity, dtype=np.uint32)))
+ _rs = np.concatenate((_rs, np.full(new_cap - capacity, UINT_MAX, dtype=np.uint32)))
+ _re = np.concatenate((_re, np.zeros(new_cap - capacity, dtype=np.uint32)))
+ _cs = np.concatenate((_cs, np.full(new_cap - capacity, UINT_MAX, dtype=np.uint32)))
+ _ce = np.concatenate((_ce, np.zeros(new_cap - capacity, dtype=np.uint32)))
+ zs = _zs; ze = _ze; rs = _rs; re = _re; cs = _cs; ce = _ce
+ capacity = new_cap
+ if label > max_label:
+ max_label = label
+ if i < zs[label]: zs[label] = i
+ if i + 1 > ze[label]: ze[label] = (i + 1)
+ if j < rs[label]: rs[label] = j
+ if j + 1 > re[label]: re[label] = (j + 1)
+ if k < cs[label]: cs[label] = k
+ if k + 1 > ce[label]: ce[label] = (k + 1)
+
+ if max_label == 0:
+ return [], []
+
+ # Collect present labels into compact numpy arrays (avoids per-label tuple allocation)
+ cdef unsigned int n_labels = 0
+ for lbl in range(1, max_label + 1):
+ if ze[lbl] != 0:
+ n_labels += 1
+
+ cdef np.ndarray[np.uint32_t, ndim=1] out_labels = np.empty(n_labels, dtype=np.uint32)
+ cdef np.ndarray[np.uint32_t, ndim=2] out_bboxes = np.empty((n_labels, 6), dtype=np.uint32)
+ cdef unsigned int idx = 0
+ for lbl in range(1, max_label + 1):
+ if ze[lbl] != 0:
+ out_labels[idx] = lbl
+ out_bboxes[idx, 0] = zs[lbl]
+ out_bboxes[idx, 1] = ze[lbl]
+ out_bboxes[idx, 2] = rs[lbl]
+ out_bboxes[idx, 3] = re[lbl]
+ out_bboxes[idx, 4] = cs[lbl]
+ out_bboxes[idx, 5] = ce[lbl]
+ idx += 1
+ return out_labels, out_bboxes
+
+def most_common_projection_3D(np.uint32_t[:, :, :] lab, int axis):
+ """Most-common-value projection for a 3-D label image along `axis`.
+
+ Tie-break matches np.unique(..., return_counts=True) + np.argmax(counts),
+ i.e. the smallest label wins when counts are equal.
+ """
+ if axis < 0 or axis > 2:
+ raise ValueError(f'axis must be 0, 1, or 2. Got {axis}.')
+
+ cdef Py_ssize_t z = lab.shape[0]
+ cdef Py_ssize_t y = lab.shape[1]
+ cdef Py_ssize_t x = lab.shape[2]
+ cdef Py_ssize_t i, j, a, b, depth
+ cdef unsigned int v, vv
+ cdef unsigned int best_label, best_count, curr_count
+ cdef bint seen
+ cdef np.uint32_t[:, :] out_view
+
+ if axis == 0:
+ depth = z
+ out = np.empty((y, x), dtype=np.uint32)
+ out_view = out
+ for i in range(y):
+ for j in range(x):
+ best_count = 0
+ best_label = 0
+ for a in range(depth):
+ v = lab[a, i, j]
+ if v == 0:
+ continue
+ seen = False
+ for b in range(a):
+ if lab[b, i, j] == v:
+ seen = True
+ break
+ if seen:
+ continue
+
+ # Count all remaining occurrences of this label along the full axis.
+ curr_count = 1
+ for b in range(a + 1, depth):
+ if lab[b, i, j] == v:
+ curr_count += 1
+
+ if curr_count > best_count or (curr_count == best_count and v < best_label):
+ best_count = curr_count
+ best_label = v
+
+ out_view[i, j] = best_label
+ return out
+
+ if axis == 1:
+ depth = y
+ out = np.empty((z, x), dtype=np.uint32)
+ out_view = out
+ for i in range(z):
+ for j in range(x):
+ best_count = 0
+ best_label = 0
+ for a in range(depth):
+ v = lab[i, a, j]
+ if v == 0:
+ continue
+ seen = False
+ for b in range(a):
+ if lab[i, b, j] == v:
+ seen = True
+ break
+ if seen:
+ continue
+
+ curr_count = 1
+ for b in range(a + 1, depth):
+ if lab[i, b, j] == v:
+ curr_count += 1
+
+ if curr_count > best_count or (curr_count == best_count and v < best_label):
+ best_count = curr_count
+ best_label = v
+
+ out_view[i, j] = best_label
+ return out
+
+ depth = x
+ out = np.empty((z, y), dtype=np.uint32)
+ out_view = out
+ for i in range(z):
+ for j in range(y):
+ best_count = 0
+ best_label = 0
+ for a in range(depth):
+ v = lab[i, j, a]
+ if v == 0:
+ continue
+ seen = False
+ for b in range(a):
+ vv = lab[i, j, b]
+ if vv == v:
+ seen = True
+ break
+ if seen:
+ continue
+
+ curr_count = 1
+ for b in range(a + 1, depth):
+ vv = lab[i, j, b]
+ if vv == v:
+ curr_count += 1
+
+ if curr_count > best_count or (curr_count == best_count and v < best_label):
+ best_count = curr_count
+ best_label = v
+
+ out_view[i, j] = best_label
+ return out
+
+def object_projections_and_size_3D(
+ np.uint32_t[:, :, :] cutout,
+ unsigned int obj_id,
+):
+ """Return binary XY/XZ/YZ projections and voxel count for specified object in a cutout."""
+ cdef Py_ssize_t z = cutout.shape[0]
+ cdef Py_ssize_t y = cutout.shape[1]
+ cdef Py_ssize_t x = cutout.shape[2]
+ cdef Py_ssize_t i, j, k
+ cdef unsigned int size = 0
+
+ cdef np.ndarray[np.uint8_t, ndim=2] proj_z = np.zeros((y, x), dtype=np.uint8)
+ cdef np.ndarray[np.uint8_t, ndim=2] proj_y = np.zeros((z, x), dtype=np.uint8)
+ cdef np.ndarray[np.uint8_t, ndim=2] proj_x = np.zeros((z, y), dtype=np.uint8)
+ cdef np.uint8_t[:, :] proj_z_view = proj_z
+ cdef np.uint8_t[:, :] proj_y_view = proj_y
+ cdef np.uint8_t[:, :] proj_x_view = proj_x
+
+ for i in range(z):
+ for j in range(y):
+ for k in range(x):
+ if cutout[i, j, k] != obj_id:
+ continue
+ size += 1
+ proj_z_view[j, k] = 1
+ proj_y_view[i, k] = 1
+ proj_x_view[i, j] = 1
+
+ return proj_z, proj_y, proj_x, size
+
+def object_projection_and_size_3D(
+ np.uint32_t[:, :, :] cutout,
+ unsigned int obj_id,
+ int axis,
+):
+ """Return one binary projection and voxel count for one object in a 3-D cutout.
+
+ axis=0 -> XY projection (collapse z)
+ axis=1 -> XZ projection (collapse y)
+ axis=2 -> YZ projection (collapse x)
+ """
+ if axis < 0 or axis > 2:
+ raise ValueError(f'axis must be 0, 1, or 2. Got {axis}.')
+
+ cdef Py_ssize_t z = cutout.shape[0]
+ cdef Py_ssize_t y = cutout.shape[1]
+ cdef Py_ssize_t x = cutout.shape[2]
+ cdef Py_ssize_t i, j, k
+ cdef unsigned int size = 0
+
+ cdef np.ndarray[np.uint8_t, ndim=2] proj
+ cdef np.uint8_t[:, :] proj_view
+
+ if axis == 0:
+ proj = np.zeros((y, x), dtype=np.uint8)
+ proj_view = proj
+ for i in range(z):
+ for j in range(y):
+ for k in range(x):
+ if cutout[i, j, k] != obj_id:
+ continue
+ size += 1
+ proj_view[j, k] = 1
+ return proj, size
+
+ if axis == 1:
+ proj = np.zeros((z, x), dtype=np.uint8)
+ proj_view = proj
+ for i in range(z):
+ for j in range(y):
+ for k in range(x):
+ if cutout[i, j, k] != obj_id:
+ continue
+ size += 1
+ proj_view[i, k] = 1
+ return proj, size
+
+ proj = np.zeros((z, y), dtype=np.uint8)
+ proj_view = proj
+ for i in range(z):
+ for j in range(y):
+ for k in range(x):
+ if cutout[i, j, k] != obj_id:
+ continue
+ size += 1
+ proj_view[i, j] = 1
+
+ return proj, size
+
+def calc_IoA_matrix_2D(
+ np.uint32_t[:, :] lab,
+ np.uint32_t[:, :] prev_lab,
+ np.uint32_t[:] curr_IDs,
+ np.uint32_t[:] prev_IDs,
+ np.uint32_t[:] prev_areas,
+ np.uint32_t[:] curr_areas,
+ bint use_union,
+):
+ """Single-pass IoA matrix between two 2-D label images.
+
+ Parameters
+ ----------
+ lab, prev_lab : (Y, X) uint32 label images for current and previous frame.
+ curr_IDs : 1-D array of current object labels (row order of output).
+ prev_IDs : 1-D array of previous object labels (col order of output).
+ prev_areas : pixel area of each entry in prev_IDs.
+ curr_areas : pixel area of each entry in curr_IDs (only used when use_union=True).
+ use_union : if False, denominator is area_prev; if True, denominator is union.
+
+ Returns
+ -------
+ IoA_matrix : (n_curr, n_prev) float64 array.
+ """
+ cdef Py_ssize_t n_rows = lab.shape[0]
+ cdef Py_ssize_t n_cols = lab.shape[1]
+ cdef Py_ssize_t n_curr = curr_IDs.shape[0]
+ cdef Py_ssize_t n_prev = prev_IDs.shape[0]
+ cdef Py_ssize_t i, j, ci, pi
+ cdef unsigned int c, p, max_curr_label = 0, max_prev_label = 0
+ cdef int ci_val, pi_val
+
+ for i in range(n_curr):
+ if curr_IDs[i] > max_curr_label:
+ max_curr_label = curr_IDs[i]
+ for i in range(n_prev):
+ if prev_IDs[i] > max_prev_label:
+ max_prev_label = prev_IDs[i]
+
+ # label -> matrix-index lookup; -1 means "not in the tracked set"
+ cdef np.ndarray[np.int32_t, ndim=1] _curr_idx = np.full(max_curr_label + 1, -1, dtype=np.int32)
+ cdef np.ndarray[np.int32_t, ndim=1] _prev_idx = np.full(max_prev_label + 1, -1, dtype=np.int32)
+ cdef int[:] curr_idx = _curr_idx
+ cdef int[:] prev_idx = _prev_idx
+
+ for i in range(n_curr):
+ curr_idx[curr_IDs[i]] = i
+ for i in range(n_prev):
+ prev_idx[prev_IDs[i]] = i
+
+ cdef np.ndarray[np.uint32_t, ndim=2] _intersections = np.zeros((n_curr, n_prev), dtype=np.uint32)
+ cdef unsigned int[:, :] intersections = _intersections
+
+ # Single pass: count overlapping pixels between every (curr, prev) pair
+ for i in range(n_rows):
+ for j in range(n_cols):
+ c = lab[i, j]
+ p = prev_lab[i, j]
+ if c == 0 or p == 0:
+ continue
+ if c > max_curr_label or p > max_prev_label:
+ continue
+ ci_val = curr_idx[c]
+ pi_val = prev_idx[p]
+ if ci_val < 0 or pi_val < 0:
+ continue
+ intersections[ci_val, pi_val] += 1
+
+ cdef np.ndarray[np.float64_t, ndim=2] IoA_matrix = np.zeros((n_curr, n_prev), dtype=np.float64)
+ cdef double denom_val, I_val
+
+ for ci in range(n_curr):
+ for pi in range(n_prev):
+ I_val = intersections[ci, pi]
+ if I_val == 0.0:
+ continue
+ if use_union:
+ denom_val = (curr_areas[ci] + prev_areas[pi]) - I_val
+ else:
+ denom_val = prev_areas[pi]
+ if denom_val == 0.0:
+ continue
+ IoA_matrix[ci, pi] = I_val / denom_val
+
+ return IoA_matrix
+
+def calc_IoA_matrix_3D(
+ np.uint32_t[:, :, :] lab,
+ np.uint32_t[:, :, :] prev_lab,
+ np.uint32_t[:] curr_IDs,
+ np.uint32_t[:] prev_IDs,
+ np.uint32_t[:] prev_areas,
+ np.uint32_t[:] curr_areas,
+ bint use_union,
+):
+ """Single-pass IoA matrix between two 3-D label images. See calc_IoA_matrix_2D."""
+ cdef Py_ssize_t n_z = lab.shape[0]
+ cdef Py_ssize_t n_rows = lab.shape[1]
+ cdef Py_ssize_t n_cols = lab.shape[2]
+ cdef Py_ssize_t n_curr = curr_IDs.shape[0]
+ cdef Py_ssize_t n_prev = prev_IDs.shape[0]
+ cdef Py_ssize_t i, j, k, ci, pi
+ cdef unsigned int c, p, max_curr_label = 0, max_prev_label = 0
+ cdef int ci_val, pi_val
+
+ for i in range(n_curr):
+ if curr_IDs[i] > max_curr_label:
+ max_curr_label = curr_IDs[i]
+ for i in range(n_prev):
+ if prev_IDs[i] > max_prev_label:
+ max_prev_label = prev_IDs[i]
+
+ cdef np.ndarray[np.int32_t, ndim=1] _curr_idx = np.full(max_curr_label + 1, -1, dtype=np.int32)
+ cdef np.ndarray[np.int32_t, ndim=1] _prev_idx = np.full(max_prev_label + 1, -1, dtype=np.int32)
+ cdef int[:] curr_idx = _curr_idx
+ cdef int[:] prev_idx = _prev_idx
+
+ for i in range(n_curr):
+ curr_idx[curr_IDs[i]] = i
+ for i in range(n_prev):
+ prev_idx[prev_IDs[i]] = i
+
+ cdef np.ndarray[np.uint32_t, ndim=2] _intersections = np.zeros((n_curr, n_prev), dtype=np.uint32)
+ cdef unsigned int[:, :] intersections = _intersections
+
+ for i in range(n_z):
+ for j in range(n_rows):
+ for k in range(n_cols):
+ c = lab[i, j, k]
+ p = prev_lab[i, j, k]
+ if c == 0 or p == 0:
+ continue
+ if c > max_curr_label or p > max_prev_label:
+ continue
+ ci_val = curr_idx[c]
+ pi_val = prev_idx[p]
+ if ci_val < 0 or pi_val < 0:
+ continue
+ intersections[ci_val, pi_val] += 1
+
+ cdef np.ndarray[np.float64_t, ndim=2] IoA_matrix = np.zeros((n_curr, n_prev), dtype=np.float64)
+ cdef double denom_val, I_val
+
+ for ci in range(n_curr):
+ for pi in range(n_prev):
+ I_val = intersections[ci, pi]
+ if I_val == 0.0:
+ continue
+ if use_union:
+ denom_val = (curr_areas[ci] + prev_areas[pi]) - I_val
+ else:
+ denom_val = prev_areas[pi]
+ if denom_val == 0.0:
+ continue
+ IoA_matrix[ci, pi] = I_val / denom_val
+
+ return IoA_matrix
\ No newline at end of file
diff --git a/cellacdc/regionprops.py b/cellacdc/regionprops.py
new file mode 100644
index 000000000..4f449db94
--- /dev/null
+++ b/cellacdc/regionprops.py
@@ -0,0 +1,1203 @@
+import numpy as np
+from scipy import ndimage as ndi
+import skimage.measure
+import cv2
+from . import printl, debugutils
+from skimage.measure._regionprops_utils import (
+ _normalize_spacing,
+)
+import traceback as traceback
+
+try:
+ from cellacdc.precompiled.precompiled_functions import (
+ find_all_objects_2D,
+ find_all_objects_3D,
+ object_projections_and_size_3D,
+ object_projection_and_size_3D,
+ )
+ _CYTHON_FIND_OBJECTS = True
+ _CYTHON_OBJECT_PROJECTIONS = True
+ print('regionprops: imported precompiled find-objects helpers.')
+except Exception:
+ _CYTHON_FIND_OBJECTS = False
+ _CYTHON_OBJECT_PROJECTIONS = False
+ print('[WARNING]: regionprops could not import precompiled find-objects helpers, falling back to scipy.ndimage.find_objects.')
+
+try:
+ from cellacdc.precompiled.precompiled_functions import most_common_projection_3D
+ _CYTHON_MOST_COMMON_PROJECTION = True
+ print('regionprops: imported precompiled most-common projection helper.')
+except Exception:
+ _CYTHON_MOST_COMMON_PROJECTION = False
+ print('[WARNING]: regionprops could not import precompiled most-common projection helper, falling back to NumPy implementation.')
+
+# WARNING: Developers have already used
+# 14 hrs
+# to optimize this.
+# In addition, implementing these optimizations in the codebase took
+# 9 hrs
+# Specifically the
+# centroid (huge gain for 3D data)
+# contour caching
+# better find objects implementation to avoid iterating over None lists
+# bbox caching
+# targeted updates to RP
+# stuff was targeted.
+# If you decide to try and optimize it further, please update this warning :)
+
+_RegionProperties = skimage.measure._regionprops.RegionProperties
+_cached = skimage.measure._regionprops._cached
+
+def _most_common_projection_ignore_zero_numpy(lab, axis):
+ """Most-common projection that ignores label 0 unless all values are 0."""
+ moved = np.moveaxis(lab, axis, 0)
+ depth = moved.shape[0]
+ flat = moved.reshape(depth, -1)
+ out = np.zeros(flat.shape[1], dtype=lab.dtype)
+
+ for col in range(flat.shape[1]):
+ line = flat[:, col]
+ nonzero = line[line != 0]
+ if nonzero.size == 0:
+ continue
+
+ labels, counts = np.unique(nonzero, return_counts=True)
+ # np.unique sorts ascending, so ties are resolved by smallest label.
+ out[col] = labels[np.argmax(counts)]
+
+ return out.reshape(moved.shape[1:])
+
+def _object_projection_and_size_numpy(cutout, obj_id, axis):
+ mask = cutout == obj_id
+ proj = np.any(mask, axis=axis).astype(np.uint8)
+ size = int(np.count_nonzero(mask))
+ return proj, size
+
+def _acdc_regionprops_factory(
+ label_image,
+ intensity_image=None,
+ cache=True,
+ *,
+ extra_properties=None,
+ spacing=None,
+ offset=None,
+ ):
+ if label_image.ndim not in (2, 3):
+ raise TypeError('Only 2-D and 3-D images supported.')
+
+ if not np.issubdtype(label_image.dtype, np.integer):
+ if np.issubdtype(label_image.dtype, bool):
+ raise TypeError(
+ 'Non-integer image types are ambiguous: '
+ 'use skimage.measure.label to label the connected '
+ 'components of label_image, '
+ 'or label_image.astype(np.uint8) to interpret '
+ 'the True values as a single label.'
+ )
+ raise TypeError('Non-integer label_image types are ambiguous')
+
+ if offset is None:
+ offset_arr = np.zeros((label_image.ndim,), dtype=int)
+ else:
+ offset_arr = np.asarray(offset)
+ if offset_arr.ndim != 1 or offset_arr.size != label_image.ndim:
+ raise ValueError(
+ 'Offset should be an array-like of integers '
+ 'of shape (label_image.ndim,); '
+ f'{offset} was provided.'
+ )
+
+ regions = []
+ if _CYTHON_FIND_OBJECTS:
+ img_uint32 = label_image.astype(np.uint32, copy=False)
+ if label_image.ndim == 2:
+ out = find_all_objects_2D(img_uint32)
+ labels, bboxes = out
+ for i in range(len(labels)):
+ sl = (slice(int(bboxes[i, 0]), int(bboxes[i, 1])),
+ slice(int(bboxes[i, 2]), int(bboxes[i, 3])))
+ regions.append(acdcRegionProperties(
+ sl, int(labels[i]), label_image, intensity_image, cache,
+ spacing=spacing, extra_properties=extra_properties,
+ offset=offset_arr,
+ ))
+ else:
+ out = find_all_objects_3D(img_uint32)
+ labels, bboxes = out
+ for i in range(len(labels)):
+ sl = (slice(int(bboxes[i, 0]), int(bboxes[i, 1])),
+ slice(int(bboxes[i, 2]), int(bboxes[i, 3])),
+ slice(int(bboxes[i, 4]), int(bboxes[i, 5])))
+ regions.append(acdcRegionProperties(
+ sl, int(labels[i]), label_image, intensity_image, cache,
+ spacing=spacing, extra_properties=extra_properties,
+ offset=offset_arr,
+ ))
+ else:
+ objects = ndi.find_objects(label_image)
+ for i, sl in enumerate(objects, start=1):
+ if sl is None:
+ continue
+ regions.append(acdcRegionProperties(
+ sl, i, label_image, intensity_image, cache,
+ spacing=spacing, extra_properties=extra_properties,
+ offset=offset_arr,
+ ))
+ return regions
+
+class acdcRegionProperties(_RegionProperties):
+ def __init__(
+ self,
+ slice,
+ label,
+ label_image,
+ intensity_image,
+ cache_active,
+ *,
+ extra_properties=None,
+ spacing=None,
+ offset=None,
+ ):
+ super().__init__(
+ slice, label, label_image, intensity_image, cache_active,
+ extra_properties=extra_properties, spacing=spacing, offset=offset
+ )
+ # @property
+ # @_cached
+ # def slice(self):
+ # # scale slice with offset
+ # return tuple(
+ # slice(self._slice[i].start + self._offset[i],
+ # self._slice[i].stop + self._offset[i])
+ # for i in range(self._ndim)
+ # )
+
+ @property
+ def image(self):
+ """Return cached object mask from the current label image."""
+ imgage = self._cache.get('image')
+ if imgage is None or not np.any(imgage):
+ self._cache['image'] = self._label_image[self._slice] == self.label
+
+ return self._cache['image']
+
+ @property
+ @_cached
+ def bbox(self):
+ """
+ Returns
+ -------
+ A tuple of the bounding box's start coordinates for each dimension,
+ followed by the end coordinates for each dimension.
+ """
+ return tuple(
+ [self.slice[i].start for i in range(self._ndim)]
+ + [self.slice[i].stop for i in range(self._ndim)]
+ )
+
+ @property
+ @_cached # slow for 3D data, better cache it
+ def centroid(self):
+ return super().centroid
+
+ @property
+ @_cached
+ def contour(self):
+ contours = self._contours_local(retrieve_mode=cv2.RETR_EXTERNAL)
+ if not contours:
+ return np.empty((0, 2), dtype=np.int32)
+
+ contour = max(contours, key=len)
+ contour = np.squeeze(contour, axis=1)
+ contour = np.vstack((contour, contour[0]))
+ return contour + self._xy_offset
+
+ @property
+ @_cached
+ def contour_all(self):
+ # Include both outer boundaries and holes.
+ contours = self._contours_local(retrieve_mode=cv2.RETR_CCOMP)
+ if not contours:
+ return []
+ offset = self._xy_offset
+ return [np.squeeze(cont, axis=1) + offset for cont in contours]
+
+ @property
+ @_cached
+ def _xy_offset(self):
+ if self._ndim != 2:
+ raise AttributeError('contour is only supported for 2D objects.')
+ slc = self.slice
+ return np.array([slc[1].start, slc[0].start], dtype=np.int32)
+
+ def _contours_local(self, retrieve_mode=cv2.RETR_EXTERNAL):
+ if self._ndim != 2:
+ raise AttributeError('contour is only supported for 2D objects.')
+ obj_image = np.ascontiguousarray(self.image, dtype=np.uint8)
+ contours, _ = cv2.findContours(
+ obj_image, retrieve_mode, cv2.CHAIN_APPROX_NONE
+ )
+ return contours
+
+class acdcRegionprops:
+ def __init__(
+ self,
+ lab,
+ acdc_df=None,
+ centroids_loaded=None,
+ IDs_loaded=None,
+ centroids_IDs_exact_loaded=None,
+ ID_to_idx_loaded=None,
+ precache_centroids=True,
+ **kwargs,
+ ):
+ self.lab = lab
+ self.acdc_df = acdc_df
+ self._rp = _acdc_regionprops_factory(lab, **kwargs)
+ self.is3D = self.lab.ndim == 3
+ self._slice_rps = {
+ 'z': {},
+ 'y': {},
+ 'x': {},
+ }
+ self._proj_rps = {
+ 'z': {},
+ 'y': {},
+ 'x': {},
+ }
+ self._proj_labs = {
+ 'z': {},
+ 'y': {},
+ 'x': {},
+ }
+ self._centroid_mapper = {}
+ self._centroid_IDs_exact = set()
+ if IDs_loaded is None or ID_to_idx_loaded is None:
+ self.set_attributes(update_centroid_mapper=False)
+ else:
+ self.ID_to_idx = ID_to_idx_loaded
+ self.IDs_set = set(IDs_loaded)
+ self.IDs = list(self.IDs_set)
+
+ if centroids_IDs_exact_loaded is not None and centroids_loaded is not None:
+ self._centroid_mapper = centroids_loaded
+ self._centroid_IDs_exact = set(centroids_IDs_exact_loaded)
+ elif precache_centroids:
+ self.precache_centroids()
+
+ else:
+ self._centroid_mapper = dict()
+
+ def get_slice_rp(self, slice_number, slicing='z'):
+ if not self.is3D:
+ raise ValueError('Slice-specific regionprops are only supported for 3D labels.')
+
+ slicing = self._normalize_slicing(slicing)
+ slice_number = int(slice_number)
+ self._validate_slice_number(slice_number, slicing)
+
+ rp = self._slice_rps[slicing].get(slice_number)
+ if rp is None:
+ lab_slice = self._get_lab_slice(self.lab, slice_number, slicing)
+ rp = acdcRegionprops(lab_slice, precache_centroids=False)
+ self._slice_rps[slicing][slice_number] = rp
+ return rp
+
+ def get_proj_rp(self, kind='max', slicing='z'):
+ if not self.is3D:
+ raise ValueError('Projection-specific regionprops are only supported for 3D labels.')
+
+ slicing = self._normalize_slicing(slicing)
+ kind = self._normalize_projection_kind(kind)
+
+ rp = self._proj_rps[slicing].get(kind)
+ if rp is None:
+ lab_proj = self._proj_labs[slicing].get(kind)
+ if lab_proj is None:
+ lab_proj = self._get_lab_projection(self.lab, slicing=slicing, kind=kind)
+ self._proj_labs[slicing][kind] = lab_proj
+ rp = acdcRegionprops(lab_proj, precache_centroids=False)
+ self._proj_rps[slicing][kind] = rp
+ return rp
+
+ def get_obj_from_slice_rp(self, ID, slice_number, slicing='z', warn=True):
+ rp = self.get_slice_rp(slice_number, slicing=slicing)
+ return rp.get_obj_from_ID(ID, warn=warn)
+
+ def get_obj_from_proj_rp(self, ID, kind='max', slicing='z', warn=True):
+ kind = self._normalize_projection_kind(kind)
+ rp = self.get_proj_rp(kind=kind, slicing=slicing)
+ return rp.get_obj_from_ID(ID, warn=warn)
+
+ def __iter__(self):
+ return iter(self._rp)
+
+ def __len__(self):
+ return len(self._rp)
+
+ def __getitem__(self, idx):
+ return self._rp[idx]
+
+ def __setitem__(self, idx, value):
+ self._rp[idx] = value
+
+ def __repr__(self):
+ return repr(self._rp)
+
+ def _normalize_slicing(self, slicing):
+ slicing = str(slicing).lower()
+ if slicing not in ('z', 'y', 'x'):
+ raise ValueError(
+ f'Invalid slicing "{slicing}". Valid options are "z", "y", and "x".'
+ )
+ return slicing
+
+ def _slice_axis_index(self, slicing):
+ axis_map = {'z': 0, 'y': 1, 'x': 2}
+ return axis_map[slicing]
+
+ def _normalize_projection_kind(self, kind):
+ kind = str(kind).lower().strip()
+ kind_norm = kind.replace('-', ' ').replace('_', ' ')
+
+ if kind_norm.startswith('max'):
+ return 'max'
+ if kind_norm.startswith('mean'):
+ return 'mean'
+ if kind_norm.startswith('median'):
+ return 'median'
+ if kind_norm.startswith('most common') or kind_norm.startswith('mode'):
+ return 'most_common'
+
+ if kind not in ('max', 'mean', 'median', 'most_common'):
+ raise ValueError(
+ f'Invalid projection kind "{kind}". '
+ 'Valid options are "max", "mean", "median", and "most_common".'
+ )
+ return kind
+
+ def _validate_slice_number(self, slice_number, slicing):
+ axis = self._slice_axis_index(slicing)
+ axis_size = self.lab.shape[axis]
+ if slice_number < 0 or slice_number >= axis_size:
+ raise IndexError(
+ f'Slice number {slice_number} is out of bounds for slicing "{slicing}" '
+ f'with size {axis_size}.'
+ )
+
+ def _has_initialized_slice_rps(self):
+ return any(len(slice_dict) > 0 for slice_dict in self._slice_rps.values())
+
+ def _has_initialized_proj_rps(self):
+ return any(len(proj_dict) > 0 for proj_dict in self._proj_rps.values())
+
+ def _iter_initialized_slice_rps(self):
+ for slicing, slice_dict in self._slice_rps.items():
+ for slice_number, rp in slice_dict.items():
+ yield slicing, slice_number, rp
+
+ def _iter_initialized_proj_rps(self):
+ for slicing, proj_dict in self._proj_rps.items():
+ for kind, rp in proj_dict.items():
+ yield slicing, kind, rp
+
+ def _get_lab_slice(self, lab, slice_number, slicing):
+ if lab.ndim != 3:
+ raise ValueError(
+ f'Slice-specific regionprops are only supported for 3D labels, got {lab.ndim}D.'
+ )
+
+ slicing = self._normalize_slicing(slicing)
+ if slicing == 'z':
+ return lab[slice_number, :, :]
+ if slicing == 'y':
+ return lab[:, slice_number, :]
+ return lab[:, :, slice_number]
+
+ def _get_lab_projection(self, lab, slicing='z', kind='max'):
+ if lab.ndim != 3:
+ raise ValueError(
+ f'Projection-specific regionprops are only supported for 3D labels, got {lab.ndim}D.'
+ )
+
+ axis = self._slice_axis_index(self._normalize_slicing(slicing))
+ kind = self._normalize_projection_kind(kind)
+ if kind == 'max':
+ return np.max(lab, axis=axis)
+
+ if kind == 'most_common':
+ return self._compute_most_common_projection(lab, axis=axis)
+
+ if kind == 'mean':
+ projected = np.mean(lab, axis=axis)
+ else:
+ projected = np.median(lab, axis=axis)
+
+ # Regionprops requires integer labels.
+ return np.rint(projected).astype(lab.dtype, copy=False)
+
+ def _compute_most_common_projection(self, lab, axis):
+ if _CYTHON_MOST_COMMON_PROJECTION:
+ lab_uint32 = lab.astype(np.uint32, copy=False)
+ projected = most_common_projection_3D(lab_uint32, int(axis))
+ return projected.astype(lab.dtype, copy=False)
+
+ projected = _most_common_projection_ignore_zero_numpy(lab, axis)
+ return projected.astype(lab.dtype, copy=False)
+
+ def _projection_canvas_shape(self, slicing):
+ slicing = self._normalize_slicing(slicing)
+ if slicing == 'z':
+ return (self.lab.shape[1], self.lab.shape[2])
+ if slicing == 'y':
+ return (self.lab.shape[0], self.lab.shape[2])
+ return (self.lab.shape[0], self.lab.shape[1])
+
+ def _project_cutout_for_slicing_and_size(self, cutout, obj_id, slicing):
+ axis = self._slice_axis_index(slicing)
+ if _CYTHON_OBJECT_PROJECTIONS:
+ cutout_uint32 = cutout.astype(np.uint32, copy=False)
+ proj, size = object_projection_and_size_3D(
+ cutout_uint32, int(obj_id), int(axis)
+ )
+ return proj.astype(bool, copy=False), int(size)
+
+ proj, size = _object_projection_and_size_numpy(cutout, int(obj_id), int(axis))
+ return proj.astype(bool, copy=False), int(size)
+
+ def _project_object_from_bbox(self, lab, obj_id, bbox, slicing):
+ z0, y0, x0, z1, y1, x1 = [int(v) for v in bbox]
+ if z0 >= z1 or y0 >= y1 or x0 >= x1:
+ return None
+
+ cutout = lab[z0:z1, y0:y1, x0:x1]
+ proj_mask, size = self._project_cutout_for_slicing_and_size(
+ cutout, obj_id, slicing
+ )
+ if size == 0:
+ return None
+
+ if slicing == 'z':
+ out_slice = (slice(y0, y1), slice(x0, x1))
+ elif slicing == 'y':
+ out_slice = (slice(z0, z1), slice(x0, x1))
+ else:
+ out_slice = (slice(z0, z1), slice(y0, y1))
+
+ return out_slice, proj_mask, int(size)
+
+ def _iter_projected_objects_sorted(self, slicing='z', lab=None):
+ if not self.is3D:
+ raise ValueError('Projection helpers are only supported for 3D labels.')
+
+ slicing = self._normalize_slicing(slicing)
+ if lab is None:
+ lab = self.lab
+
+ projected = []
+ for obj in self._rp:
+ if len(obj.bbox) != 6:
+ continue
+ out = self._project_object_from_bbox(lab, obj.label, obj.bbox, slicing)
+ if out is None:
+ continue
+ out_slice, proj_mask, size = out
+ projected.append((int(obj.label), int(size), out_slice, proj_mask))
+
+ # Draw large objects first, then smaller objects so small ones stay visible on top.
+ projected.sort(key=lambda entry: (-entry[1], entry[0]))
+ return projected
+
+ def get_projection_lab_sorted(self, slicing='z', dtype=None):
+ if not self.is3D:
+ raise ValueError('Projection helpers are only supported for 3D labels.')
+
+ slicing = self._normalize_slicing(slicing)
+ out_shape = self._projection_canvas_shape(slicing)
+ if dtype is None:
+ dtype = self.lab.dtype
+
+ lab_proj = np.zeros(out_shape, dtype=dtype)
+ for label, _, out_slice, proj_mask in self._iter_projected_objects_sorted(slicing=slicing):
+ lab_proj[out_slice][proj_mask] = label
+ return lab_proj
+
+ def _get_projection_patch_slices(self, slicing, cutout_bbox):
+ z0, y0, x0, z1, y1, x1 = [int(v) for v in cutout_bbox]
+ if slicing == 'z':
+ return (slice(y0, y1), slice(x0, x1))
+ if slicing == 'y':
+ return (slice(z0, z1), slice(x0, x1))
+ return (slice(z0, z1), slice(y0, y1))
+
+ def _get_projection_patch_lab(self, slicing, cutout_bbox):
+ z0, y0, x0, z1, y1, x1 = [int(v) for v in cutout_bbox]
+ if slicing == 'z':
+ return self.lab[:, y0:y1, x0:x1]
+ if slicing == 'y':
+ return self.lab[z0:z1, :, x0:x1]
+ return self.lab[z0:z1, y0:y1, :]
+
+ def _compute_most_common_projection_patch(self, slicing, cutout_bbox):
+ patch_lab = self._get_projection_patch_lab(slicing, cutout_bbox)
+ axis = self._slice_axis_index(slicing)
+ return self._compute_most_common_projection(patch_lab, axis=axis)
+
+ def _update_cached_most_common_projection_locally(self, slicing, cutout_bbox):
+ lab_proj = self._get_cached_or_new_lab_projection(slicing, 'most_common')
+ patch_slices = self._get_projection_patch_slices(slicing, cutout_bbox)
+ if any(slc.start >= slc.stop for slc in patch_slices):
+ return lab_proj
+
+ patch = self._compute_most_common_projection_patch(slicing, cutout_bbox)
+ lab_proj[patch_slices] = patch
+ self._proj_labs[slicing]['most_common'] = lab_proj
+ return lab_proj
+
+ def _get_cached_or_new_lab_projection(self, slicing, kind):
+ lab_proj = self._proj_labs[slicing].get(kind)
+ if lab_proj is None:
+ lab_proj = self._get_lab_projection(self.lab, slicing=slicing, kind=kind)
+ self._proj_labs[slicing][kind] = lab_proj
+ return lab_proj
+
+ def _replace_cached_lab_projection(self, slicing, kind):
+ lab_proj = self._get_lab_projection(self.lab, slicing=slicing, kind=kind)
+ self._proj_labs[slicing][kind] = lab_proj
+ return lab_proj
+
+ def _apply_assignments_to_lab_from_rp(self, rp, lab, assignments):
+ active_assignments = {
+ int(old_ID): int(new_ID)
+ for old_ID, new_ID in assignments.items()
+ if old_ID != new_ID
+ }
+ if not active_assignments:
+ return lab
+
+ dst = lab.copy()
+ for obj in rp:
+ new_ID = active_assignments.get(obj.label)
+ if new_ID is None:
+ continue
+ dst[obj.slice][obj.image] = new_ID
+ return dst
+
+ def _apply_deletions_to_lab_from_rp(self, rp, lab, IDs_to_delete):
+ IDs_to_delete = set(IDs_to_delete)
+ if not IDs_to_delete:
+ return lab
+
+ dst = lab.copy()
+ for obj in rp:
+ if obj.label in IDs_to_delete:
+ dst[obj.slice][obj.image] = 0
+ return dst
+
+ def _sync_initialized_slice_rps_via_assignments(self, assignments):
+ if not self._has_initialized_slice_rps():
+ return
+
+ for slicing, slice_number, rp in self._iter_initialized_slice_rps():
+ lab_slice = self._get_lab_slice(self.lab, slice_number, slicing)
+ rp.update_regionprops_via_assignments(assignments, lab_slice)
+
+ def _sync_initialized_proj_rps_via_assignments(self, assignments):
+ if not self._has_initialized_proj_rps():
+ return
+
+ for slicing, kind, rp in self._iter_initialized_proj_rps():
+ lab_proj = self._get_cached_or_new_lab_projection(slicing, kind)
+ lab_proj = self._apply_assignments_to_lab_from_rp(
+ rp, lab_proj, assignments
+ )
+ self._proj_labs[slicing][kind] = lab_proj
+ rp.update_regionprops_via_assignments(assignments, lab_proj)
+
+ def _sync_initialized_slice_rps_via_deletions(self, IDs_to_delete):
+ if not self._has_initialized_slice_rps():
+ return
+
+ for _, _, rp in self._iter_initialized_slice_rps():
+ rp.update_regionprops_via_deletions(IDs_to_delete)
+
+ def _sync_initialized_proj_rps_via_deletions(self, IDs_to_delete):
+ if not self._has_initialized_proj_rps():
+ return
+
+ for slicing, kind, rp in self._iter_initialized_proj_rps():
+ lab_proj = self._get_cached_or_new_lab_projection(slicing, kind)
+ lab_proj = self._apply_deletions_to_lab_from_rp(
+ rp, lab_proj, IDs_to_delete
+ )
+ self._proj_labs[slicing][kind] = lab_proj
+ rp.update_regionprops_via_deletions(IDs_to_delete)
+
+ def _sync_initialized_slice_rps_via_update(self, specific_IDs_update_centroids=None):
+ if not self._has_initialized_slice_rps():
+ return
+
+ for slicing, slice_number, rp in self._iter_initialized_slice_rps():
+ lab_slice = self._get_lab_slice(self.lab, slice_number, slicing)
+ rp.update_regionprops(
+ lab_slice,
+ specific_IDs_update_centroids=specific_IDs_update_centroids,
+ )
+
+ def _sync_initialized_proj_rps_via_update(
+ self,
+ specific_IDs_update_centroids=None,
+ cutout_bbox=None,
+ ):
+ if not self._has_initialized_proj_rps():
+ return
+
+ for slicing, kind, rp in self._iter_initialized_proj_rps():
+ if cutout_bbox is not None and kind == 'most_common':
+ lab_proj = self._update_cached_most_common_projection_locally(
+ slicing, cutout_bbox
+ )
+ else:
+ lab_proj = self._replace_cached_lab_projection(slicing, kind)
+ rp.update_regionprops(
+ lab_proj,
+ specific_IDs_update_centroids=specific_IDs_update_centroids,
+ )
+
+ def _normalize_cutout_bbox(self, cutout_bbox):
+ """Normalize cutout_bbox to always be 3D (6 values) for 3D data.
+
+ Automatically expands 2D bbox (4 values: y_start, x_start, y_end, x_end)
+ to 3D bbox (6 values: z_start, y_start, x_start, z_end, y_end, x_end)
+ covering all z-slices.
+ """
+ if self.is3D:
+ if len(cutout_bbox) == 4:
+ # 2D bbox: expand to 3D with full z range
+ y_start, x_start, y_end, x_end = cutout_bbox
+ return (0, y_start, x_start, self.lab.shape[0], y_end, x_end)
+ elif len(cutout_bbox) != 6:
+ raise ValueError(
+ 'For 3D labels, cutout_bbox should have 4 values (2D) or 6 values (3D): '
+ f'got {len(cutout_bbox)}.'
+ )
+ else:
+ if len(cutout_bbox) != 4:
+ raise ValueError(
+ 'For 2D labels, cutout_bbox should have 4 values (y_start, x_start, y_end, x_end), '
+ f'got {len(cutout_bbox)}.'
+ )
+ return cutout_bbox
+
+ def _get_centroid_df_from_df(self):
+ if self.acdc_df is None or len(self.acdc_df) == 0:
+ return {}
+
+ centroid_cols = ['y_centroid', 'x_centroid']
+ if self.is3D and 'z_centroid' in self.acdc_df.columns:
+ centroid_cols = ['z_centroid', 'y_centroid', 'x_centroid']
+
+ if not set(centroid_cols).issubset(self.acdc_df.columns):
+ return {}
+
+ if 'Cell_ID' in self.acdc_df.columns:
+ centroid_df = self.acdc_df.set_index('Cell_ID')[centroid_cols]
+ elif 'ID' in self.acdc_df.columns:
+ centroid_df = self.acdc_df.set_index('ID')[centroid_cols]
+ else:
+ centroid_df = self.acdc_df[centroid_cols]
+
+ return {
+ int(ID): tuple(values)
+ for ID, values in centroid_df.iterrows()
+ }
+
+ def _get_bbox_centers_mapper(
+ self, objs=None, IDs_to_include=None, IDs_to_exclude=None
+ ):
+ if objs is None and not self._rp:
+ return {}
+
+ if objs is None:
+ if IDs_to_include is None:
+ IDs_to_include = (
+ self.IDs_set.difference(IDs_to_exclude)
+ if IDs_to_exclude is not None else self.IDs_set
+ )
+ ids = set(IDs_to_include)
+ objs = [obj for obj in self._rp if obj.label in ids]
+
+ if not objs:
+ return {}
+
+ ndim = 2 if not self.is3D else 3
+ labels = np.empty(len(objs), dtype=int)
+ bboxes = np.empty((len(objs), ndim * 2), dtype=float)
+ for i, obj in enumerate(objs):
+ labels[i] = obj.label
+ bboxes[i] = obj.bbox
+
+ centers = (bboxes[:, :ndim] + bboxes[:, ndim:]) / 2.0
+ return {
+ int(label): tuple(center)
+ for label, center in zip(labels, centers)
+ }
+
+ def precache_centroids(self):
+ centroid_df = self._get_centroid_df_from_df()
+ IDs_from_df = set(centroid_df)
+ IDs_missing_centroid = self.IDs_set.difference(IDs_from_df)
+ bbox_centers_mapper = self._get_bbox_centers_mapper(
+ IDs_to_include=IDs_missing_centroid
+ )
+ self._centroid_mapper = {**bbox_centers_mapper, **centroid_df}
+ self._centroid_IDs_exact = IDs_from_df
+
+ def set_attributes(self, deleted_IDs=None, update_centroid_mapper=True):
+ self.ID_to_idx = {obj.label: idx for idx, obj in enumerate(self._rp)}
+ # Update IDs and IDs_set separately and explicitly
+ self.IDs_set = set(self.ID_to_idx)
+ self.IDs = list(self.IDs_set)
+
+ if not update_centroid_mapper:
+ return
+ if deleted_IDs is not None:
+ for ID in deleted_IDs:
+ self._centroid_mapper.pop(ID, None)
+ self._centroid_IDs_exact.discard(ID)
+ else:
+ self._centroid_mapper = {
+ ID: centroid
+ for ID, centroid in self._centroid_mapper.items()
+ if ID in self.IDs_set
+ }
+ self._centroid_IDs_exact.intersection_update(self.IDs_set)
+
+ def get_obj_from_ID(self, ID, warn=True):
+ idx = self.ID_to_idx.get(ID, None)
+ if idx is not None:
+ return self._rp[idx]
+ else:
+ if warn:
+ # get caller info
+ debugutils.print_call_stack()
+ print(f"Warning: Object with ID {ID} not found in regionprops.")
+ return None
+
+ def delete_IDs(self, IDs_to_delete: set[int], update_other_attrs=True):
+ if not IDs_to_delete:
+ return
+
+ self._rp = [
+ obj for obj in self._rp if obj.label not in IDs_to_delete
+ ]
+
+ if not update_other_attrs:
+ return
+ self.set_attributes(deleted_IDs=IDs_to_delete)
+
+ def _get_IDs_to_update_centroids(
+ self, lab, objs, specific_IDs_update_centroids=None
+ ):
+ if specific_IDs_update_centroids is not None:
+ return set(specific_IDs_update_centroids)
+
+ obj_to_update = set()
+ for obj in objs:
+ has_to_update = False
+ ID = obj.label
+ old_centroid = self._centroid_mapper.get(ID, None)
+ if old_centroid is not None:
+ rounded_centroid = tuple(np.round(old_centroid).astype(int))
+ try:
+ ID_lab = lab[rounded_centroid]
+ except Exception:
+ has_to_update = True
+ else:
+ if ID_lab != ID:
+ has_to_update = True
+ else:
+ has_to_update = True
+
+ if has_to_update:
+ obj_to_update.add(ID)
+
+ return obj_to_update
+
+ def update_regionprops(
+ self, lab, specific_IDs_update_centroids=None,
+ update_centroids=True
+ ):
+ old_rp_by_id = {obj.label: obj for obj in self._rp}
+
+ new_rp = _acdc_regionprops_factory(lab)
+
+ if update_centroids:
+ # Verify that the cached centroid is still inside the object mask.
+ obj_to_update = self._get_IDs_to_update_centroids(
+ lab, new_rp,
+ specific_IDs_update_centroids=specific_IDs_update_centroids
+ )
+
+ bbox_centers_mapper = self._get_bbox_centers_mapper(
+ objs=[obj for obj in new_rp if obj.label in obj_to_update]
+ )
+
+ # update centroids
+ self._centroid_mapper.update(bbox_centers_mapper)
+
+ # remove from exact set if we updated the centroid
+ self._centroid_IDs_exact.difference_update(obj_to_update)
+
+ for obj in new_rp:
+ self._copy_custom_rp_attributes(obj, old_rp_by_id.get(obj.label))
+
+ self._rp = new_rp
+ self.lab = lab
+ self.set_attributes()
+ self._sync_initialized_slice_rps_via_update(
+ specific_IDs_update_centroids=specific_IDs_update_centroids
+ )
+ self._sync_initialized_proj_rps_via_update(
+ specific_IDs_update_centroids=specific_IDs_update_centroids
+ )
+
+ def _copy_custom_rp_attributes(self, new_obj, old_obj):
+ if old_obj is None:
+ return
+ new_obj.dead = getattr(old_obj, 'dead', False)
+ new_obj.excluded = getattr(old_obj, 'excluded', False)
+
+ def _get_bbox_slices(self, bbox, depth_axis=None):
+ ndim = self.lab.ndim
+ if len(bbox) != ndim * 2:
+ raise ValueError(
+ f'Expected a bounding box with {ndim*2} values, '
+ f'got {len(bbox)}.'
+ )
+
+ return tuple(
+ slice(int(bbox[dim]), int(bbox[dim+ndim])) for dim in range(ndim)
+ )
+
+ def _translate_cutout_regionprop(self, obj, offset, lab):
+ offset_arr = np.asarray(offset)
+ centroid = obj.centroid
+ translated_slice = tuple(
+ slice(
+ obj._slice[dim].start + offset_arr[dim],
+ obj._slice[dim].stop + offset_arr[dim],
+ )
+ for dim in range(obj._ndim)
+ )
+ translated_bbox = tuple(
+ [slc.start for slc in translated_slice]
+ + [slc.stop for slc in translated_slice]
+ )
+ translated_centroid = tuple(
+ coord + offset_arr[dim]
+ for dim, coord in enumerate(centroid)
+ )
+
+ obj._label_image = lab
+ obj._slice = translated_slice
+ obj.slice = translated_slice
+ obj._offset = np.zeros_like(offset_arr)
+ obj._cache['slice'] = translated_slice
+ obj._cache['bbox'] = translated_bbox
+ obj._cache['centroid'] = translated_centroid
+ return obj
+
+ def _get_separate_obj_regionprops(self, lab, IDs):
+ IDs = tuple(int(ID) for ID in IDs)
+ if not IDs:
+ return {}
+
+ mask = np.isin(lab, IDs)
+ if not np.any(mask):
+ return {}
+
+ isolated_lab = np.zeros_like(lab)
+ isolated_lab[mask] = lab[mask]
+ return {
+ obj.label: obj
+ for obj in _acdc_regionprops_factory(isolated_lab)
+ if obj.label in IDs
+ }
+
+ def _is_bbox_touching_cutout_border(self, bbox, shape):
+ ndim = len(shape)
+ for dim in range(ndim):
+ if bbox[dim] == 0 or bbox[dim+ndim] == shape[dim]:
+ return True
+ return False
+
+ def _obj_intersects_bbox(self, obj, bbox):
+ ndim = self.lab.ndim
+ obj_bbox = obj.bbox
+ for dim in range(ndim):
+ start = max(int(obj_bbox[dim]), int(bbox[dim]))
+ stop = min(int(obj_bbox[dim+ndim]), int(bbox[dim+ndim]))
+ if start >= stop:
+ return False
+
+ return True
+
+ def _get_old_cutout_IDs_from_rp(self, cutout_bbox):
+ return {
+ obj.label for obj in self._rp
+ if self._obj_intersects_bbox(obj, cutout_bbox)
+ }
+
+ def _set_label_image(self, lab, objs=None, clear_cache=False):
+ if lab is None:
+ return
+
+ self.lab = lab
+ if objs is None:
+ objs = self._rp
+
+ for obj in objs:
+ obj._label_image = lab
+ if clear_cache:
+ obj._cache.clear()
+
+ def update_regionprops_via_assignments(
+ self, assignments:dict[int, int], lab
+ ):
+ """If the lab is completely the same, but only ID changes/swaps have been made
+
+ Parameters
+ ----------
+ assignments : dict[int, int]
+ key: old ID,
+ value: new ID
+ lab : np.ndarray, optional
+ Updated label image. When provided, regionprops objects are rebound
+ to this image so properties such as ``image`` stay consistent after
+ the ID remap.
+ """
+ active_assignments = {
+ int(old_ID): int(new_ID)
+ for old_ID, new_ID in assignments.items()
+ if old_ID in self.IDs_set and old_ID != new_ID
+ }
+ if not active_assignments:
+ self._set_label_image(lab)
+ self._sync_initialized_slice_rps_via_assignments({})
+ self._sync_initialized_proj_rps_via_assignments({})
+ return
+
+ # if not active_assignments:
+ # if lab is not None:
+ # self._set_label_image(lab)
+ # return
+
+ # remapped_IDs = set()
+ # for obj in self._rp:
+ # old_ID = obj.label
+ # new_ID = active_assignments.get(old_ID, old_ID)
+ # if new_ID in remapped_IDs:
+ # raise ValueError(
+ # 'Assignments would create duplicate IDs in regionprops. '
+ # 'Use a full regionprops recomputation for merges.'
+ # )
+ # remapped_IDs.add(new_ID)
+
+ centroid_mapper = {
+ active_assignments.get(ID, ID): centroid
+ for ID, centroid in self._centroid_mapper.items()
+ # if active_assignments.get(ID, ID) in remapped_IDs
+ }
+ centroid_IDs_exact = {
+ active_assignments.get(ID, ID)
+ for ID in self._centroid_IDs_exact
+ # if active_assignments.get(ID, ID) in remapped_IDs
+ }
+
+ # Rebind first so any property access during remap sees the current lab.
+ self._set_label_image(lab)
+
+ for obj in self._rp:
+ old_ID = obj.label
+ new_ID = active_assignments.get(old_ID, old_ID)
+ obj.label = new_ID
+ # if obj.area == 0:
+ # # if area is 0, centroid is not defined and we should not trust the cached one
+ # print("area 0...")
+
+ self._centroid_mapper = centroid_mapper
+ self._centroid_IDs_exact = centroid_IDs_exact
+ self.set_attributes(update_centroid_mapper=False) # update the mapper
+ self._sync_initialized_slice_rps_via_assignments(active_assignments)
+ self._sync_initialized_proj_rps_via_assignments(active_assignments)
+
+ def update_regionprops_via_deletions(
+ self, IDs_to_delete: set[int]
+ ):
+ """If the lab is completely the same, but only some IDs have been deleted
+
+ Parameters
+ ----------
+ IDs_to_delete : set[int]
+ IDs to delete
+ """
+ IDs_to_delete = set(IDs_to_delete).intersection(self.IDs_set)
+ if not IDs_to_delete:
+ return
+ self._rp = [obj for obj in self._rp if obj.label not in IDs_to_delete]
+ self.set_attributes(deleted_IDs=IDs_to_delete) # for updating the IDs to indx, centroid mapper
+ self._sync_initialized_slice_rps_via_deletions(IDs_to_delete)
+ self._sync_initialized_proj_rps_via_deletions(IDs_to_delete)
+
+ def update_regionprops_via_cutout(
+ self, lab, cutout_bbox, specific_IDs=None, depth_axis=None
+ ):
+ """Only relabels the regionprops of a specific cutout.
+ Is only faster for small cutouts. I dont have a number, but I would say
+ less than 30% of total image size.
+
+ Parameters
+ ----------
+ cutout_lab : np.ndarray
+ The labeled cutout image.
+ cutout_bbox : tuple[int, int, int, int]
+ The bounding box of the cutout in the format (min_row, min_col, max_row, max_col).
+ """
+ if specific_IDs is not None and not isinstance(specific_IDs, (list, set, np.ndarray, tuple)):
+ specific_IDs = {specific_IDs}
+ elif specific_IDs is not None:
+ specific_IDs = set(specific_IDs)
+
+ self.lab = lab
+ # Normalize bbox to 3D (expands 2D bbox to full z-range)
+ cutout_bbox = self._normalize_cutout_bbox(cutout_bbox)
+ cutout_slices = self._get_bbox_slices(cutout_bbox, depth_axis=depth_axis)
+ new_cutout = lab[cutout_slices]
+ old_cutout_IDs = self._get_old_cutout_IDs_from_rp(cutout_bbox)
+ rp_cutout_new = _acdc_regionprops_factory(new_cutout)
+ new_cutout_IDs = set(obj.label for obj in rp_cutout_new)
+
+ if not old_cutout_IDs and not new_cutout_IDs:
+ return
+
+ target_IDs = (
+ old_cutout_IDs.union(new_cutout_IDs)
+ if specific_IDs is None
+ else old_cutout_IDs.union(new_cutout_IDs).intersection(specific_IDs)
+ )
+
+ deleted_target_IDs = old_cutout_IDs.difference(new_cutout_IDs).intersection(
+ target_IDs
+ )
+
+ refreshed_IDs = new_cutout_IDs.intersection(target_IDs)
+
+ conflicting_IDs = refreshed_IDs.difference(old_cutout_IDs).intersection(
+ self.IDs_set.difference(old_cutout_IDs)
+ )
+ if conflicting_IDs:
+ raise ValueError(
+ 'Cutout update would reuse IDs that already belong to objects '
+ 'outside the cutout. Use a full regionprops recomputation.'
+ )
+
+ old_rp_by_id = {obj.label: obj for obj in self._rp}
+ IDs_to_replace = old_cutout_IDs.intersection(target_IDs)
+ unaffected_rp = [obj for obj in self._rp if obj.label not in IDs_to_replace]
+
+ offset = tuple(s.start for s in cutout_slices)
+
+ border_touching_IDs = {
+ obj.label
+ for obj in rp_cutout_new
+ if obj.label in refreshed_IDs
+ and self._is_bbox_touching_cutout_border(obj.bbox, new_cutout.shape)
+ }
+ separate_objs = self._get_separate_obj_regionprops(lab, border_touching_IDs)
+
+ new_objs = []
+ updated_centroid_IDs = set()
+ for obj in rp_cutout_new:
+ ID = obj.label
+ if ID not in refreshed_IDs:
+ continue
+ if ID in border_touching_IDs:
+ # edge case: ID changed is outside the cutout
+ new_obj = separate_objs.get(ID)
+ if new_obj is None:
+ continue
+ else:
+ new_obj = self._translate_cutout_regionprop(obj, offset, lab)
+
+ self._copy_custom_rp_attributes(new_obj, old_rp_by_id.get(ID))
+ new_objs.append(new_obj)
+ updated_centroid_IDs.add(ID)
+
+ for ID in deleted_target_IDs:
+ self._centroid_mapper.pop(ID, None)
+ self._centroid_IDs_exact.discard(ID)
+
+ if updated_centroid_IDs:
+ obj_to_update = self._get_IDs_to_update_centroids(
+ lab, new_objs,
+ specific_IDs_update_centroids=updated_centroid_IDs
+ )
+
+ self._centroid_mapper.update(
+ self._get_bbox_centers_mapper(
+ objs=[obj for obj in new_objs if obj.label in obj_to_update]
+ )
+ )
+ self._centroid_IDs_exact.difference_update(obj_to_update)
+
+ self._rp = unaffected_rp + new_objs
+ self._set_label_image(lab)
+ self.set_attributes(update_centroid_mapper=False)
+ self._sync_initialized_slice_rps_via_update(
+ specific_IDs_update_centroids=target_IDs
+ )
+ self._sync_initialized_proj_rps_via_update(
+ specific_IDs_update_centroids=target_IDs,
+ cutout_bbox=cutout_bbox,
+ )
+
+ def get_centroid(self, ID, exact=False):
+ if exact and ID not in self._centroid_IDs_exact:
+ obj = self.get_obj_from_ID(ID)
+ centroid = obj.centroid
+ try:
+ int(centroid[0])
+ except (TypeError, ValueError):
+ print(f"Warning: Centroid for ID {ID} is not a valid coordinate: {centroid}. "
+ f"Object size: {obj.bbox}. Returning None.")
+ return None
+ self._centroid_mapper[ID] = centroid
+ self._centroid_IDs_exact.add(ID)
+ return centroid
+
+ centroid = self._centroid_mapper.get(ID, None)
+ if centroid is None:
+ # add centroid to mapper if not found
+ objs = [self.get_obj_from_ID(ID)]
+ bbox_centers_mapper = self._get_bbox_centers_mapper(objs=objs)
+ self._centroid_mapper.update(bbox_centers_mapper)
+ centroid = self._centroid_mapper.get(ID, None)
+ return centroid
+
+ def copy(self):
+ new_instance = acdcRegionprops(
+ self.lab, precache_centroids=False
+ )
+ new_instance._rp = [obj for obj in self._rp]
+ new_instance._centroid_mapper = self._centroid_mapper.copy()
+ new_instance._centroid_IDs_exact = self._centroid_IDs_exact.copy()
+ for slicing, slice_number, rp in self._iter_initialized_slice_rps():
+ new_instance._slice_rps[slicing][slice_number] = rp.copy()
+ for slicing, kind, rp in self._iter_initialized_proj_rps():
+ new_instance._proj_rps[slicing][kind] = rp.copy()
+ for slicing, proj_labs in self._proj_labs.items():
+ for kind, lab_proj in proj_labs.items():
+ new_instance._proj_labs[slicing][kind] = lab_proj.copy()
+ new_instance.set_attributes(update_centroid_mapper=False)
+ return new_instance
\ No newline at end of file
diff --git a/cellacdc/segm.py b/cellacdc/segm.py
index c75c2589d..293d13ece 100755
--- a/cellacdc/segm.py
+++ b/cellacdc/segm.py
@@ -518,7 +518,7 @@ def main(self):
model_name = win.selectedModel
- if model_name == 'thresholding':
+ if model_name in ('thresholding', 'Automatic thresholding'):
win = apps.QDialogAutomaticThresholding(
parent=self, isSegm3D=self.isSegm3D
)
@@ -528,6 +528,9 @@ def main(self):
return
self.model_kwargs = win.segment_kwargs
+ if model_name == 'Automatic thresholding':
+ model_name = 'thresholding'
+
self.log(f'Downloading {model_name} (if needed)...')
self.downloadWin = apps.downloadModel(model_name, parent=self)
self.downloadWin.download()
diff --git a/cellacdc/segm_utils.py b/cellacdc/segm_utils.py
index 790e59719..880defd9e 100644
--- a/cellacdc/segm_utils.py
+++ b/cellacdc/segm_utils.py
@@ -3,11 +3,10 @@
import time
from .core import segm_model_segment, post_process_segm
from .features import custom_post_process_segm
-from . import io, plot
+from . import io, plot, regionprops
import inspect
-
import os # for dbug
import json # for dbug
@@ -38,6 +37,22 @@ def find_overlap(lab_1, lab_2):
return ID_overlap
+def get_best_overlapping_label(label_img, obj, allowed_labels):
+ allowed_labels = set(allowed_labels)
+ if len(allowed_labels) == 0:
+ return None
+
+ overlapping_labels = label_img[obj.slice][obj.image]
+ if overlapping_labels.size == 0:
+ return None
+
+ overlapping_labels = overlapping_labels[np.isin(overlapping_labels, tuple(allowed_labels))]
+ if overlapping_labels.size == 0:
+ return None
+
+ labels, counts = np.unique(overlapping_labels, return_counts=True)
+ return labels[np.argmax(counts)]
+
def get_obj_from_rps(rps, ID):
for obj in rps:
if obj.label == ID:
@@ -163,10 +178,18 @@ def boxes_overlap(bbox1, bbox2):
# # Use np.unique once on the combined array
# return np.unique(border_labels[border_labels != 0])
-def single_cell_seg(model, prev_lab, curr_lab, curr_img, IDs, new_unique_ID,
- win, posData, distance_filler_growth=1,
+def single_cell_seg(model, prev_lab, curr_lab, curr_img,
+ IDs, new_unique_ID,
+ posData, distance_filler_growth=1,
overlap_threshold=0.5, padding=0.4,
export_bbox_for_training=False,
+ model_kwargs=None,
+ preproc_recipe=None,
+ applyPostProcessing=False,
+ standardPostProcessKwargs=None,
+ customPostProcessFeatures=None,
+ customPostProcessGroupedFeatures=None,
+ debug=False,
):
"""
Function to segment single cells in the current frame using the previous frame segmentation as a reference.
@@ -178,12 +201,17 @@ def single_cell_seg(model, prev_lab, curr_lab, curr_img, IDs, new_unique_ID,
curr_img: current frame image
IDs: list of IDs of the cells to segment
new_unique_ID: ID to start labeling new cells
- win: from the gui window which sets model params
posData: position data (see rest of acdc)
distance_filler_growth: distance to grow the other IDs to fill the background
overlap_threshold: minimum overlap percentage to consider a cell already segmented
padding: padding around the cell to segment
export_bbox_for_training: if True, export bounding boxes for training model
+ model_kwargs: keyword arguments to pass to the segmentation model
+ preproc_recipe: preprocessing recipe to apply before segmentation
+ applyPostProcessing: if True, apply post-processing to the segmentation
+ standardPostProcessKwargs: keyword arguments for standard post-processing
+ customPostProcessFeatures: custom features for post-processing segmentation
+ customPostProcessGroupedFeatures: custom grouped features for post-processing
Returns:
curr_lab: current frame segmentation with the segmented cells
@@ -194,12 +222,10 @@ def single_cell_seg(model, prev_lab, curr_lab, curr_img, IDs, new_unique_ID,
if export_bbox_for_training:
bboxs_for_debug = []
- model_kwargs = win.model_kwargs
- preproc_recipe = win.preproc_recipe
- applyPostProcessing = win.applyPostProcessing
- standardPostProcessKwargs = win.standardPostProcessKwargs
- customPostProcessFeatures = win.customPostProcessFeatures
- customPostProcessGroupedFeatures = win.customPostProcessGroupedFeatures
+ if model_kwargs is None:
+ model_kwargs = {}
+ if standardPostProcessKwargs is None:
+ standardPostProcessKwargs = {}
prev_rp = skimage.measure.regionprops(prev_lab)
prev_lab_shape = prev_lab.shape
@@ -210,23 +236,37 @@ def single_cell_seg(model, prev_lab, curr_lab, curr_img, IDs, new_unique_ID,
assigned_IDs = []
uses_diameter = inspect.signature(model.segment).parameters.get('diameter', None) is not None
- for IDs, bbox in zip(IDs_bboxs, bboxs):
+
+ if debug:
+ imgs_to_show = {
+ i: [] for i in range(len(bboxs))
+ }
+ for i, (IDs, bbox) in enumerate(zip(IDs_bboxs, bboxs)):
box_x_min, box_x_max, box_y_min, box_y_max = bbox
box_curr_img = curr_img[box_x_min:box_x_max, box_y_min:box_y_max].copy()
box_curr_lab = curr_lab[box_x_min:box_x_max, box_y_min:box_y_max]
+
+ if debug:
+ imgs_to_show[i].append(box_curr_img.copy())
+ imgs_to_show[i].append(box_curr_lab.copy())
box_curr_lab_other_IDs = box_curr_lab.copy()
IDs = np.array(IDs)
box_curr_lab_other_IDs[np.isin(box_curr_lab_other_IDs, IDs)] = 0
box_curr_lab_other_IDs_grown = skimage.segmentation.expand_labels(box_curr_lab_other_IDs, distance=distance_filler_growth)
+ if debug:
+ imgs_to_show[i].append(box_curr_lab_other_IDs_grown.copy())
# Fill other IDs with random samples from the background
indices_to_fill = np.where(box_curr_lab_other_IDs_grown != 0)
box_background = box_curr_img[box_curr_lab_other_IDs_grown==0]
random_samples = np.random.choice(box_background, size=indices_to_fill[0].shape, replace=True)
box_curr_img[indices_to_fill] = random_samples
+
+ if debug:
+ imgs_to_show[i].append(box_curr_img.copy())
# Run model, give it the diameter of cell if possible
if uses_diameter:
@@ -248,6 +288,9 @@ def single_cell_seg(model, prev_lab, curr_lab, curr_img, IDs, new_unique_ID,
posData=posData,
)
+ if debug:
+ imgs_to_show[i].append(box_model_lab.copy())
+
if export_bbox_for_training:
bboxs_for_debug.append([IDs, bbox, box_model_lab.copy(), box_curr_lab.copy()])
@@ -269,14 +312,14 @@ def single_cell_seg(model, prev_lab, curr_lab, curr_img, IDs, new_unique_ID,
### maybe add roi extension if cells are deleted...
# Find the overlap between the model segmentation and the other IDs
- overlap = find_overlap(box_model_lab, box_curr_lab_other_IDs)
+ overlaps = find_overlap(box_model_lab, box_curr_lab_other_IDs)
# Set overlapping regions to 0, so already segmented cells are not overwritten
- for ID, overlap_perc in overlap:
- if overlap_perc > overlap_threshold:
- box_model_lab[box_model_lab == ID] = 0
+ IDs_to_filter = [ID for ID, overlap_perc in overlaps if overlap_perc > overlap_threshold]
+ if IDs_to_filter:
+ box_model_lab[np.isin(box_model_lab, IDs_to_filter)] = 0
- rp_model_lab = skimage.measure.regionprops(box_model_lab)
+ rp_model_lab = regionprops.acdcRegionprops(box_model_lab,precache_centroids=False)
for obj in rp_model_lab:
box_curr_lab_other_IDs[box_model_lab == obj.label] = new_unique_ID
assigned_IDs.append(new_unique_ID)
@@ -321,4 +364,6 @@ def single_cell_seg(model, prev_lab, curr_lab, curr_img, IDs, new_unique_ID,
with open(json_filepath, 'w') as f:
json.dump(loaded_dict, f, indent=4)
+ if debug:
+ return curr_lab, assigned_IDs, IDs_bboxs, bboxs, imgs_to_show
return curr_lab, assigned_IDs, IDs_bboxs, bboxs
\ No newline at end of file
diff --git a/cellacdc/segmentation.py b/cellacdc/segmentation.py
deleted file mode 100644
index 9bfba56ad..000000000
--- a/cellacdc/segmentation.py
+++ /dev/null
@@ -1,139 +0,0 @@
-import numpy as np
-
-import skimage.segmentation
-import skimage.measure
-
-import cv2
-
-def _find_contours_2D(
- image, bbox_lower_coords=(0, 0), all=False, closed=True
- ):
- mode = cv2.RETR_CCOMP if all else cv2.RETR_EXTERNAL
- contours, _ = cv2.findContours(image, mode, cv2.CHAIN_APPROX_NONE)
-
- if all:
- all_contours = [
- np.squeeze(contour, axis=1)+bbox_lower_coords
- for contour in contours
- ]
- if closed:
- all_contours = [
- np.vstack((contour, contour[0])) for contour in contours
- ]
- return all_contours
- else:
- contour = np.squeeze(contours[0], axis=1)
- if closed:
- contour = np.vstack((contour, contour[0]))
- contour = contour + bbox_lower_coords
- return contour
-
-def find_obj_contour(
- obj: skimage.measure._regionprops.RegionProperties, all=False,
- local=False, do_z_max_proj=False, closed=True
- ):
- is3D = obj.image.ndim == 3
- bbox_y_idx = 1 if is3D else 0
-
- if local:
- bbox_lower_coords=(0, 0)
- else:
- min_y, min_x = obj.bbox[bbox_y_idx:bbox_y_idx+2]
- bbox_lower_coords = (min_x, min_y)
-
- if is3D and do_z_max_proj:
- is3D = False
- obj_image = obj.max(axis=0).astype(np.uint8)
- else:
- obj_image = obj.image.astype(np.uint8)
-
- kwargs = {
- 'bbox_lower_coords': bbox_lower_coords,
- 'all':all, 'closed': closed
- }
- if is3D:
- contours = [
- _find_contours_2D(image_z, **kwargs) for image_z in obj_image
- ]
- else:
- contours = _find_contours_2D(obj_image, **kwargs)
- return contours
-
-def find_contours(
- label_img, connectivity=1, mode='thick', background=0,
- return_coords=False, **kwargs
- ):
- """Return bool array where boundaries between labeled regions are True.
- If `return_coords` is True then return also a list of objects' contours
- coordinates.
-
- Parameters
- ----------
- label_img : (M, N[, P]) ndarray
- An array in which different regions are labeled with either different
- integers or boolean values.
- connectivity : int, optional
- int in {1, ..., `label_img.ndim`}, optional
- A pixel is considered a boundary pixel if any of its neighbors
- has a different label. `connectivity` controls which pixels are
- considered neighbors. A connectivity of 1 (default) means
- pixels sharing an edge (in 2D) or a face (in 3D) will be
- considered neighbors. A connectivity of `label_img.ndim` means
- pixels sharing a corner will be considered neighbors. Default is 1.
- mode : str, optional
- How to mark the boundaries:
- - thick: any pixel not completely surrounded by pixels of the
- same label (defined by `connectivity`) is marked as a boundary.
- This results in boundaries that are 2 pixels thick.
- - inner: outline the pixels *just inside* of objects, leaving
- background pixels untouched.
- - outer: outline pixels in the background around object
- boundaries. When two objects touch, their boundary is also
- marked.
- - subpixel: return a doubled image, with pixels *between* the
- original pixels marked as boundary where appropriate.,
-
- By default 'thick'
- background : int, optional
- For modes 'inner' and 'outer', a definition of a background
- label is required. See `mode` for descriptions of these two,
- by default 0
- return_coords : bool, optional
- If ``True``, also return a list of objects' contours coordinates,
- by default False
- kwargs : dict, optional
- Additional arguments passed `acdctools.segmentation.find_obj_contour`
- function. This function uses the opencv find contours function
- `cv2.findContours`. Used only if `mode='inner'`.
-
- Returns
- -------
- boundaries : ndarray of bool, same shape as `label_img`
- A bool image where ``True`` represents a boundary pixel. For
- `mode` equal to 'subpixel', ``boundaries.shape[i]`` is equal
- to ``2 * label_img.shape[i] - 1`` for all ``i`` (a pixel is
- inserted in between all other pairs of pixels).
- contours_coords: list of ndarray
- A list of ndarrays with shape (N, n) where `n` is the number of
- dimensions of `label_img` and `N` is the number of points in each
- object's contour. The list contains one ndarray per object in
- `label_img`.
- The ordering of columns follows the numpy's order of dimensions
- convention, e.g., for 2-D, the first and second column are the
- y and x coordinates, respectively.
- Only provided if `return_coords` is True.
- """
- boundaries = skimage.segmentation.find_boundaries(
- label_img, connectivity=connectivity, mode=mode, background=background
- )
- if not return_coords:
- return boundaries
-
- is2D = label_img.ndim == 2
- rp = skimage.measure.regionprops(label_img)
- contours_coords = []
- for obj in rp:
- if mode == 'inner' and is2D:
- pass
- else:
- pass
diff --git a/cellacdc/trackers/CellACDC/CellACDC_tracker.py b/cellacdc/trackers/CellACDC/CellACDC_tracker.py
index 07cfe841b..c02bddb99 100755
--- a/cellacdc/trackers/CellACDC/CellACDC_tracker.py
+++ b/cellacdc/trackers/CellACDC/CellACDC_tracker.py
@@ -5,82 +5,137 @@
import numpy as np
from skimage.measure import regionprops
+from cellacdc.regionprops import acdcRegionprops
from skimage.segmentation import relabel_sequential
-from cellacdc import core, printl
+from cellacdc import core, printl, debugutils
+
+try:
+ from cellacdc.precompiled.precompiled_functions import (
+ calc_IoA_matrix_2D as _calc_IoA_matrix_2D_cython,
+ calc_IoA_matrix_3D as _calc_IoA_matrix_3D_cython,
+ )
+ _HAS_CYTHON_IOA = True
+ print('tracking: imported precompiled IoA helpers.')
+except ImportError:
+ _HAS_CYTHON_IOA = False
+ print('[WARNING]: tracking could not import precompiled IoA helpers, falling back to NumPy implementation.')
DEBUG = False
+def _normalize_specific_IDs(specific_IDs):
+ if specific_IDs is None:
+ return None
+ if isinstance(specific_IDs, (list, tuple, set, np.ndarray)):
+ return set(specific_IDs)
+ return {specific_IDs}
+
+def _filter_subset_assignments(old_IDs, tracked_IDs, all_curr_IDs, specific_IDs):
+ if specific_IDs is None:
+ return old_IDs, tracked_IDs
+
+ selected_curr_IDs = set(specific_IDs)
+ other_curr_IDs = set(all_curr_IDs).difference(selected_curr_IDs)
+ filtered_old_IDs = []
+ filtered_tracked_IDs = []
+ for old_ID, tracked_ID in zip(old_IDs, tracked_IDs):
+ if tracked_ID in other_curr_IDs:
+ continue
+ filtered_old_IDs.append(old_ID)
+ filtered_tracked_IDs.append(tracked_ID)
+
+ return filtered_old_IDs, filtered_tracked_IDs
+
def calc_Io_matrix(lab, prev_lab, rp, prev_rp, IDs_curr_untracked=None,
- denom:str='area_prev', IDs=None):
- # maybe its faster to calculate IoU not via mask but via area1 / (area1 + area2 - intersection)
- IDs_prev = []
- if IDs_curr_untracked is None:
+ specific_IDs=None,
+ denom:str='area_prev'):
+ specific_IDs = _normalize_specific_IDs(specific_IDs)
+ if IDs_curr_untracked is None and isinstance(rp, acdcRegionprops):
+ IDs_curr_untracked = rp.IDs
+ elif IDs_curr_untracked is None:
IDs_curr_untracked = [obj.label for obj in rp]
+ elif not isinstance(IDs_curr_untracked, list):
+ IDs_curr_untracked = list(IDs_curr_untracked)
- IoA_matrix = np.zeros((len(rp), len(prev_rp)))
- rp_mapper = {obj.label: obj for obj in rp}
- idx_mapper = {obj.label: i for i, obj in enumerate(rp)}
- # For each ID in previous frame get IoA with all current IDs
- # Rows: IDs in current frame, columns: IDs in previous frame
-
- ### just an idea for having area_curr as a denom possibility: just switch all around...
- # if denom == 'area_curr':
- # # switch prev with curr
- # prev_lab_temp, prev_rp_temp = prev_lab.copy(), prev_rp.copy()
- # prev_lab, prev_rp = lab.copy, rp.copy()
- # lab, rp = prev_lab_temp, prev_rp_temp
-
- if not denom in ['area_prev', 'union']:
+ if specific_IDs is not None:
+ IDs_curr_untracked = [
+ ID for ID in IDs_curr_untracked if ID in specific_IDs
+ ]
+
+ if isinstance(prev_rp, acdcRegionprops):
+ IDs_prev = prev_rp.IDs
+
+ else:
+ IDs_prev = [obj.label for obj in prev_rp]
+
+ if not IDs_curr_untracked:
+ return np.zeros((0, len(prev_rp))), IDs_curr_untracked, IDs_prev
+
+ if denom not in ('area_prev', 'union'):
raise ValueError(
"Invalid denom value. Use 'area_prev' or 'union'."
)
- # prev_label_positions = {ID_prev: np.where(prev_lab == ID_prev)[0] for ID_prev in set(prev_lab) if ID_prev != 0}
- # if denom == 'union':
- # temp_lab = np.zeros(lab.shape, dtype=bool)
- for j, obj_prev in enumerate(prev_rp):
- ID_prev = obj_prev.label
- IDs_prev.append(ID_prev)
- # if IDs is not None and ID_prev not in IDs:
- # continue
+ if _HAS_CYTHON_IOA:
+ use_union = denom == 'union'
+ curr_IDs_arr = np.array(IDs_curr_untracked, dtype=np.uint32)
+ prev_IDs_arr = np.array(IDs_prev, dtype=np.uint32)
+ prev_areas_arr = np.array([obj.area for obj in prev_rp], dtype=np.uint32)
+ if use_union:
+ rp_mapper = {obj.label: obj for obj in rp}
+ curr_areas_arr = np.array(
+ [rp_mapper[ID].area for ID in IDs_curr_untracked], dtype=np.uint32
+ )
+ else:
+ curr_areas_arr = np.empty(0, dtype=np.uint32)
+ lab_u32 = np.asarray(lab, dtype=np.uint32)
+ prev_lab_u32 = np.asarray(prev_lab, dtype=np.uint32)
+ if lab.ndim == 2:
+ IoA_matrix = _calc_IoA_matrix_2D_cython(
+ lab_u32, prev_lab_u32, curr_IDs_arr, prev_IDs_arr,
+ prev_areas_arr, curr_areas_arr, use_union,
+ )
+ else:
+ IoA_matrix = _calc_IoA_matrix_3D_cython(
+ lab_u32, prev_lab_u32, curr_IDs_arr, prev_IDs_arr,
+ prev_areas_arr, curr_areas_arr, use_union,
+ )
+ return IoA_matrix, IDs_curr_untracked, IDs_prev
- if denom == 'area_prev': # or denom == 'area_curr':
+ # --- pure-Python fallback (used when Cython extension is not compiled) ---
+ IoA_matrix = np.zeros((len(IDs_curr_untracked), len(prev_rp)))
+ rp_mapper = {obj.label: obj for obj in rp}
+ idx_mapper = {ID: i for i, ID in enumerate(IDs_curr_untracked)}
+ for j, obj_prev in enumerate(prev_rp):
+ if denom == 'area_prev':
denom_val = obj_prev.area
-
- # Get intersecting IDs between current and object in previous frame
intersect_IDs, intersects = np.unique(
lab[obj_prev.slice][obj_prev.image], return_counts=True
)
for intersect_ID, I in zip(intersect_IDs, intersects):
- if intersect_ID == 0:
- continue
-
- if I == 0:
+ if intersect_ID == 0 or I == 0:
continue
-
if denom == 'union':
+ if intersect_ID not in rp_mapper:
+ continue
obj_curr = rp_mapper[intersect_ID]
- # temp_lab[obj_prev.slice][obj_prev.image] = True
- # temp_lab[obj_curr.slice][obj_curr.image] = True
- # denom_val = np.count_nonzero(temp_lab)
- # temp_lab[:] = False
denom_val = obj_prev.area + obj_curr.area - I
if denom_val == 0:
continue
-
- idx = idx_mapper[intersect_ID]
- IoA = I/denom_val
- IoA_matrix[idx, j] = IoA
+ idx = idx_mapper.get(intersect_ID)
+ if idx is None:
+ continue
+ IoA_matrix[idx, j] = I / denom_val
return IoA_matrix, IDs_curr_untracked, IDs_prev
def assign(
IoA_matrix, IDs_curr_untracked, IDs_prev, IoA_thresh=0.4,
aggr_track=None, IoA_thresh_aggr=0.4, daughters_list=None,
- IDs=None):
+ specific_IDs=None):
# Determine max IoA between IDs and assign tracked ID if IoA >= IoA_thresh
if IoA_matrix.size == 0:
return [], []
+
max_IoA_col_idx = IoA_matrix.argmax(axis=1)
unique_col_idx, counts = np.unique(max_IoA_col_idx, return_counts=True)
counts_dict = dict(zip(unique_col_idx, counts))
@@ -187,7 +242,10 @@ def indexAssignment(
remove_untracked=False,
assign_unique_new_IDs=True,
return_assignments=False,
- IDs=None
+ dont_return_tracked_lab=False,
+ specific_IDs=None,
+ all_curr_IDs=None,
+ IDs=None,
):
"""Replace `old_IDs` in `lab` with `tracked_IDs` while making sure to
avoid merging IDs.
@@ -229,15 +287,25 @@ def indexAssignment(
assignments: dict
Returned only if `return_assignments` is True.
"""
+ specific_IDs = _normalize_specific_IDs(specific_IDs)
log_debugging(
'start',
IDs_curr_untracked=IDs_curr_untracked,
old_IDs=old_IDs
)
- # Replace untracked IDs with tracked IDs and new IDs with increasing num
+ if all_curr_IDs is None:
+ all_curr_IDs = list(IDs_curr_untracked)
+ old_IDs, tracked_IDs = _filter_subset_assignments(
+ old_IDs, tracked_IDs, all_curr_IDs, specific_IDs
+ )
+
+ # Replace untracked IDs with tracked IDs and new IDs with increasing num.
+ # When tracking only a subset of current IDs, leave unrelated labels untouched.
new_untracked_IDs = [ID for ID in IDs_curr_untracked if ID not in old_IDs]
- tracked_lab = lab
+
+ if not dont_return_tracked_lab:
+ tracked_lab = lab
assignments = {}
log_debugging(
'assign_unique',
@@ -251,9 +319,10 @@ def indexAssignment(
new_tracked_IDs = [
uniqueID+i for i in range(len(new_untracked_IDs))
]
- core.lab_replace_values(
- tracked_lab, rp, new_untracked_IDs, new_tracked_IDs
- )
+ if not dont_return_tracked_lab:
+ core.lab_replace_values(
+ tracked_lab, rp, new_untracked_IDs, new_tracked_IDs
+ )
assignments.update(dict(zip(new_untracked_IDs, new_tracked_IDs)))
log_debugging(
'new_untracked_and_assign_unique',
@@ -271,9 +340,10 @@ def indexAssignment(
new_tracked_IDs = [
uniqueID+i for i in range(len(new_IDs_in_trackedIDs))
]
- core.lab_replace_values(
- tracked_lab, rp, new_IDs_in_trackedIDs, new_tracked_IDs
- )
+ if not dont_return_tracked_lab:
+ core.lab_replace_values(
+ tracked_lab, rp, new_IDs_in_trackedIDs, new_tracked_IDs
+ )
assignments.update(dict(zip(new_IDs_in_trackedIDs, new_tracked_IDs)))
log_debugging(
'new_untracked_and_tracked',
@@ -283,10 +353,15 @@ def indexAssignment(
new_tracked_IDs=new_tracked_IDs
)
if tracked_IDs:
- core.lab_replace_values(
- tracked_lab, rp, old_IDs, tracked_IDs, in_place=True
- )
- assignments.update(dict(zip(old_IDs, tracked_IDs)))
+ if not dont_return_tracked_lab:
+ core.lab_replace_values(
+ tracked_lab, rp, old_IDs, tracked_IDs, in_place=True
+ )
+ assignments.update({
+ old_ID: tracked_ID
+ for old_ID, tracked_ID in zip(old_IDs, tracked_IDs)
+ if old_ID != tracked_ID
+ })
log_debugging(
'tracked',
tracked_IDs=tracked_IDs,
@@ -295,6 +370,8 @@ def indexAssignment(
if not return_assignments:
return tracked_lab
+ elif dont_return_tracked_lab:
+ return assignments
else:
return tracked_lab, assignments
@@ -305,17 +382,30 @@ def track_frame(
return_all=False, aggr_track=None, IoA_matrix=None,
IoA_thresh_aggr=None, IDs_prev=None, return_prev_IDs=False,
mother_daughters=None, denom_overlap_matrix = 'area_prev',
- IDs=None
+ return_assignments=False, specific_IDs=None, dont_return_tracked_lab=False
):
if not np.any(lab):
# Skip empty frames
return lab
+ all_curr_IDs = (
+ list(IDs_curr_untracked)
+ if IDs_curr_untracked is not None else None
+ )
+ if isinstance(rp, acdcRegionprops) and all_curr_IDs is None:
+ all_curr_IDs = rp.IDs
+ elif all_curr_IDs is None:
+ all_curr_IDs = [obj.label for obj in rp]
+ elif not isinstance(all_curr_IDs, list):
+ all_curr_IDs = list(all_curr_IDs)
+
if IoA_matrix is None:
- IoA_matrix, IDs_curr_untracked, IDs_prev = calc_Io_matrix(
+ IoA_matrix, tracked_curr_IDs, IDs_prev = calc_Io_matrix(
lab, prev_lab, rp, prev_rp, IDs_curr_untracked=IDs_curr_untracked,
- denom=denom_overlap_matrix, IDs=IDs
+ denom=denom_overlap_matrix,specific_IDs=specific_IDs,
)
+ else:
+ tracked_curr_IDs = IDs_curr_untracked
daughters_list = []
if mother_daughters:
@@ -323,38 +413,61 @@ def track_frame(
daughters_list.extend(daughters)
old_IDs, tracked_IDs = assign(
- IoA_matrix, IDs_curr_untracked, IDs_prev,
+ IoA_matrix, tracked_curr_IDs, IDs_prev,
IoA_thresh=IoA_thresh, aggr_track=aggr_track,
IoA_thresh_aggr=IoA_thresh_aggr, daughters_list=daughters_list,
+ specific_IDs=specific_IDs,
)
if posData is None and unique_ID is None:
unique_ID = max(
- (max(IDs_prev, default=0), max(IDs_curr_untracked, default=0))
+ (max(IDs_prev, default=0), max(all_curr_IDs, default=0))
) + 1
elif unique_ID is None:
# Compute starting unique ID
setBrushID_func(useCurrentLab=True)
unique_ID = posData.brushID+1
- if not return_all:
+ if not return_all and not return_assignments:
tracked_lab = indexAssignment(
- old_IDs, tracked_IDs, IDs_curr_untracked,
+ old_IDs, tracked_IDs, tracked_curr_IDs,
lab.copy(), rp, unique_ID,
assign_unique_new_IDs=assign_unique_new_IDs,
+ specific_IDs=specific_IDs,
+ all_curr_IDs=all_curr_IDs,
+ )
+ elif dont_return_tracked_lab:
+ assignments = indexAssignment(
+ old_IDs, tracked_IDs, tracked_curr_IDs,
+ lab.copy(), rp, unique_ID,
+ assign_unique_new_IDs=assign_unique_new_IDs,
+ return_assignments=True, specific_IDs=specific_IDs,
+ dont_return_tracked_lab=True,
+ all_curr_IDs=all_curr_IDs,
)
else:
tracked_lab, assignments = indexAssignment(
- old_IDs, tracked_IDs, IDs_curr_untracked,
+ old_IDs, tracked_IDs, tracked_curr_IDs,
lab.copy(), rp, unique_ID,
assign_unique_new_IDs=assign_unique_new_IDs,
- return_assignments=return_all,
+ return_assignments=True, specific_IDs=specific_IDs,
+ all_curr_IDs=all_curr_IDs,
)
# old_new_ids = dict(zip(old_IDs, tracked_IDs)) # for now not used, but could be useful in the future
- if return_all:
+ if return_all and dont_return_tracked_lab:
+ # special case where we want to only get the assignments but need the rest too!
+ return IoA_matrix, assignments, tracked_IDs
+ elif return_all:
return tracked_lab, IoA_matrix, assignments, tracked_IDs # remove tracked_IDs and change code in CellACDC_tracker.py if causing problems
+ elif dont_return_tracked_lab:
+ return assignments
+ elif return_assignments:
+ add_info = {
+ 'assignments': assignments,
+ }
+ return tracked_lab, add_info
else:
return tracked_lab
diff --git a/cellacdc/trackers/CellACDC_2steps/CellACDC_2steps_tracker.py b/cellacdc/trackers/CellACDC_2steps/CellACDC_2steps_tracker.py
index 4bef3ec46..198e2d6ac 100644
--- a/cellacdc/trackers/CellACDC_2steps/CellACDC_2steps_tracker.py
+++ b/cellacdc/trackers/CellACDC_2steps/CellACDC_2steps_tracker.py
@@ -13,6 +13,26 @@
import cellacdc.core
from ..CellACDC import CellACDC_tracker
+from ..CellACDC.CellACDC_tracker import _normalize_specific_IDs
+
+from cellacdc._types import NotGUIParam
+
+def _format_tracking_result(
+ tracked_lab,
+ assignments,
+ to_track_tracked_objs_2nd_step,
+ return_assignments,
+ dont_return_tracked_lab,
+ ):
+ add_info = {
+ 'assignments': assignments,
+ 'to_track_tracked_objs_2nd_step': to_track_tracked_objs_2nd_step,
+ }
+
+ if dont_return_tracked_lab:
+ return add_info
+
+ return tracked_lab, add_info # always return extra info
class SearchRangeUnits:
values = ['pixels', 'micrometre']
@@ -114,11 +134,14 @@ def track_frame(
overlap_threshold=0.4,
search_range_unit: SearchRangeUnits='pixels',
lost_IDs_search_range=10,
- unique_ID: Integer=None
+ unique_ID: Integer=None,
+ specific_IDs: NotGUIParam=None,
+ dont_return_tracked_lab=False,
+ return_assignments=False,
):
"""Track two consecutive frames in two steps. First step based on
`overlap_threshold` and second step tracks only lost objects to new
- objects detemined at first step.
+ objects determined at first step.
Parameters
----------
@@ -148,20 +171,30 @@ def track_frame(
If not None, uses this as starting ID for all the untracked objects.
If None, this will be calculated based on the two input frames.
"""
+ specific_IDs = _normalize_specific_IDs(specific_IDs)
to_track_tracked_objs_2nd_step = None
prev_rp = skimage.measure.regionprops(prev_frame_lab)
curr_rp = skimage.measure.regionprops(current_frame_lab)
- tracked_lab_1st_step = CellACDC_tracker.track_frame(
+ tracked_lab_1st_step, add_info = CellACDC_tracker.track_frame(
prev_frame_lab,
prev_rp,
current_frame_lab,
curr_rp,
IoA_thresh=overlap_threshold,
- return_prev_IDs=False,
- unique_ID=unique_ID
+ return_prev_IDs=False,
+ unique_ID=unique_ID,
+ specific_IDs=specific_IDs,
+ return_assignments=True,
)
+ assignments_step_1 = add_info['assignments']
+ selected_tracked_IDs = None
+ if specific_IDs is not None:
+ selected_tracked_IDs = {
+ assignments_step_1.get(curr_ID, curr_ID)
+ for curr_ID in specific_IDs
+ }
prev_rp_mapper = {obj.label: obj for obj in prev_rp}
@@ -176,27 +209,43 @@ def track_frame(
}
if not lost_rp_mapper:
- return tracked_lab_1st_step, to_track_tracked_objs_2nd_step
+ return _format_tracking_result(
+ tracked_lab_1st_step,
+ assignments_step_1,
+ to_track_tracked_objs_2nd_step,
+ return_assignments,
+ dont_return_tracked_lab,
+ )
new_rp_mapper = {
obj.label: obj for obj in tracked_rp_1st_step
+ if (
+ selected_tracked_IDs is None
+ or obj.label in selected_tracked_IDs
+ )
if prev_rp_mapper.get(obj.label) is None
}
-
+
if not new_rp_mapper:
- return tracked_lab_1st_step, to_track_tracked_objs_2nd_step
+ return _format_tracking_result(
+ tracked_lab_1st_step,
+ assignments_step_1,
+ to_track_tracked_objs_2nd_step,
+ return_assignments,
+ dont_return_tracked_lab,
+ )
ndim = current_frame_lab.ndim
lost_IDs_coords = np.zeros((len(lost_rp_mapper), ndim))
lost_IDs_idx_to_obj_mapper = {}
for lost_idx, lost_obj in enumerate(lost_rp_mapper.values()):
- lost_IDs_coords[lost_idx] = lost_obj.centroid
+ lost_IDs_coords[lost_idx] = lost_obj.centroid # we have overwritten RP so its always cached
lost_IDs_idx_to_obj_mapper[lost_idx] = lost_obj
new_IDs_coords = np.zeros((len(new_rp_mapper), ndim))
new_IDs_idx_to_obj_mapper = {}
for new_idx, new_obj in enumerate(new_rp_mapper.values()):
- new_IDs_coords[new_idx] = new_obj.centroid
+ new_IDs_coords[new_idx] = new_obj.centroid # we have overwritten RP so its always cached
new_IDs_idx_to_obj_mapper[new_idx] = new_obj
if search_range_unit == 'micrometre':
@@ -229,21 +278,45 @@ def track_frame(
tracked_objs_2nd_step.append(lost_IDs_idx_to_obj_mapper[i])
if not IDs_to_track:
- return tracked_lab_1st_step, to_track_tracked_objs_2nd_step
+ return _format_tracking_result(
+ tracked_lab_1st_step,
+ assignments_step_1,
+ to_track_tracked_objs_2nd_step,
+ return_assignments,
+ dont_return_tracked_lab,
+ )
- tracked_lab_2nd_step = cellacdc.core.lab_replace_values(
- tracked_lab_1st_step,
- tracked_rp_1st_step,
- IDs_to_track,
- tracked_IDs_2nd_step
- )
+ if not dont_return_tracked_lab:
+ tracked_lab_2nd_step = cellacdc.core.lab_replace_values(
+ tracked_lab_1st_step,
+ tracked_rp_1st_step,
+ IDs_to_track,
+ tracked_IDs_2nd_step
+ )
+ else:
+ tracked_lab_2nd_step = None
if self._annot_obj_2nd_step:
to_track_tracked_objs_2nd_step = (
objs_to_track, tracked_objs_2nd_step
)
- return tracked_lab_2nd_step, to_track_tracked_objs_2nd_step
+ assignments_step_2 = dict(zip(IDs_to_track, tracked_IDs_2nd_step))
+ for current_ID, tracked_ID in list(assignments_step_1.items()):
+ final_tracked_ID = assignments_step_2.get(tracked_ID)
+ if final_tracked_ID is not None:
+ assignments_step_1[current_ID] = final_tracked_ID
+
+ for current_ID, tracked_ID in assignments_step_2.items():
+ assignments_step_1.setdefault(current_ID, tracked_ID)
+
+ return _format_tracking_result(
+ tracked_lab_2nd_step,
+ assignments_step_1,
+ to_track_tracked_objs_2nd_step,
+ return_assignments,
+ dont_return_tracked_lab,
+ )
def updateGuiProgressBar(self, signals):
if signals is None:
diff --git a/cellacdc/trackers/CellACDC_normal_division/CellACDC_normal_division_tracker.py b/cellacdc/trackers/CellACDC_normal_division/CellACDC_normal_division_tracker.py
index 8e47215a9..560376dbd 100644
--- a/cellacdc/trackers/CellACDC_normal_division/CellACDC_normal_division_tracker.py
+++ b/cellacdc/trackers/CellACDC_normal_division/CellACDC_normal_division_tracker.py
@@ -4,33 +4,12 @@
from cellacdc.core import getBaseCca_df, printl
from cellacdc.myutils import checked_reset_index, checked_reset_index_Cell_ID
import numpy as np
-from skimage.measure import regionprops
from tqdm import tqdm
import pandas as pd
-from cellacdc.myutils import exec_time
from cellacdc._types import NotGUIParam
import copy
import cellacdc.debugutils as debugutils
-
-# def filter_cols(df):
-# """
-# Filters the columns of a DataFrame based on a predefined set of column names.
-# 'generation_num_tree', 'root_ID_tree', 'sister_ID_tree', 'parent_ID_tree', 'parent_ID_tree', 'emerg_frame_i', 'division_frame_i'
-# plus any column that starts with 'sister_ID_tree'
-
-# Parameters:
-# - df (pandas.DataFrame): The input DataFrame.
-
-# Returns:
-# - pandas.DataFrame: The filtered DataFrame containing only the specified columns.
-# """
-# lin_tree_cols = {'generation_num_tree', 'root_ID_tree',
-# 'sister_ID_tree', 'parent_ID_tree',
-# 'parent_ID_tree', 'emerg_frame_i',
-# 'division_frame_i', 'is_history_known'}
-# sis_cols = {col for col in df.columns if col.startswith('sister_ID_tree')}
-# lin_tree_cols = lin_tree_cols | sis_cols
-# return df[list(lin_tree_cols)]
+from cellacdc.regionprops import acdcRegionprops as acdcRegionprops
def reorg_sister_cells_for_export(lineage_tree_frame):
"""
@@ -60,45 +39,6 @@ def reorg_sister_cells_for_export(lineage_tree_frame):
return lineage_tree_frame
-# def reorg_sister_cells_inner_func(row):
-# """
-# Reorganizes the sister cells in a row of a DataFrame. Used as an inner function for apply.
-
-# Parameters:
-# - row (pandas.Series): The input row of the DataFrame (alredy filtered for the sister columns).
-# Returns:
-# - pandas.Series: The reorganized row with the sister cells.
-# """
-
-# values = [int(i) for i in row if i not in {0, -1} and not np.isnan(i)] or [-1]
-# values = list(set(values))
-# return values
-
-
-# def reorg_sister_cells_for_import(df):
-# """
-# Reorganizes the sister cells for import.
-
-# This function takes a DataFrame `df` as input and performs the following steps:
-# 1. Identifies the sister columns in the DataFrame.
-# 2. Removes any values that are equal to 0 or -1 from the sister columns. (Which both represent no sister cell)
-# 3. Converts the remaining values in the sister columns to a set.
-# 4. Converts the set of values to a list if it is not empty, otherwise assigns [-1] to the sister column. (It actually shouldn't be empty, but just in case...)
-# 5. Removes the sister columns from the DataFrame. And adds the list as the new 'sister_ID_tree' column.
-
-# Parameters:
-# - df (pandas.DataFrame): The input DataFrame.
-
-# Returns:
-# - df (pandas.DataFrame): The modified DataFrame with reorganized sister cells.
-# """
-# sister_cols = [col for col in df.columns if col.startswith('sister_ID_tree')] # handling sister columns
-# df.loc[:, 'sister_ID_tree'] = df[sister_cols].apply(reorg_sister_cells_inner_func, axis=1)
-# sister_cols.remove('sister_ID_tree')
-# df = df.drop(columns=sister_cols)
-# df = checked_reset_index_Cell_ID(df)
-# return df
-
def mother_daughter_assign(IoA_matrix, IoA_thresh_daughter, min_daughter, max_daughter, IoA_thresh_instant=None):
"""
Identifies cells that have not undergone division based on the input IoA matrix.
@@ -151,13 +91,8 @@ def mother_daughter_assign(IoA_matrix, IoA_thresh_daughter, min_daughter, max_da
else:
should_remove_idx.append(False)
- # printl(f'length of mother_daughters: {len(mother_daughters), len(should_remove_idx)}')
mother_daughters = [mother_daughters[i] for i, remove in enumerate(should_remove_idx) if not remove]
- # daughters_li = []
- # for _, daughters in mother_daughters:
- # daughters_li.extend(daughters)
-
return aggr_track, mother_daughters
def added_lineage_tree_to_cca_df(added_lineage_tree):
@@ -241,33 +176,6 @@ def IoA_index_daughter_to_ID(daughters, assignments, IDs_curr_untracked):
return daughter_IDs
-# def update_fam_dynamically(families, fixed_df, Cell_IDs_fixed=None):
-# if Cell_IDs_fixed is None:
-# Cell_IDs_fixed = fixed_df.index
-# for idx, family in enumerate(families):
-# # Keep only cellinfos where cell_id is in Cell_IDs_fixed
-# families[idx] = [cellinfo for cellinfo in family if cellinfo[0] not in Cell_IDs_fixed]
-
-# families = [family for family in families if family] # Remove empty families
-# handled_cells = set()
-# for family in families:
-# root_ID = family[0][0] # The first cell in the family is the root
-# try:
-# relevant_cells = fixed_df.loc[fixed_df['root_ID_tree'] == root_ID]
-# except:
-# printl(fixed_df['root_ID_tree'])
-# for relevant_cell in relevant_cells.index:
-# # Update the family with the generation number and root ID
-# family.append((relevant_cell, relevant_cells.loc[relevant_cell, 'generation_num_tree']))
-# handled_cells.update(relevant_cells.index)
-
-# for cell_id in Cell_IDs_fixed:
-# if cell_id not in handled_cells:
-# # If the cell is not handled, create a new family for it
-# families.append([(cell_id, fixed_df.loc[cell_id, 'generation_num_tree'])])
-
-# return families
-
class normal_division_tracker:
"""
A class that tracks cell divisions in a video sequence. The tracker uses the Intersection over Area (IoA) metric to track cells and identify daughter cells.
@@ -323,7 +231,8 @@ def __init__(self,
self.tracked_video[0] = segm_video[0]
def track_frame(self, frame_i, lab=None, prev_lab=None, rp=None, prev_rp=None,
- IDs=None, unique_ID=None):
+ IDs=None, unique_ID=None,
+ return_assignments=False, specific_IDs=None, dont_return_tracked_lab=False):
"""
Tracks a single frame in the video sequence.
@@ -342,42 +251,75 @@ def track_frame(self, frame_i, lab=None, prev_lab=None, rp=None, prev_rp=None,
prev_lab = self.tracked_video[frame_i-1]
if rp is None:
- self.rp = regionprops(lab.copy())
+ self.rp = acdcRegionprops(lab.copy(), precache_centroids=False)
else:
self.rp = rp
if prev_rp is None:
- prev_rp = regionprops(prev_lab.copy())
-
- IoA_matrix, self.IDs_curr_untracked, self.IDs_prev = calc_Io_matrix(lab,
- prev_lab,
- self.rp,
- prev_rp,
- IDs=IDs,
- )
- self.aggr_track, self.mother_daughters = mother_daughter_assign(IoA_matrix,
- IoA_thresh_daughter=self.IoA_thresh_daughter,
- min_daughter=self.min_daughter,
- max_daughter=self.max_daughter,
- IoA_thresh_instant=self.IoA_thresh
- )
- self.tracked_lab, IoA_matrix, self.assignments, _ = track_frame_base(prev_lab,
- prev_rp,
- lab,
- self.rp,
- IoA_thresh=self.IoA_thresh,
- IoA_matrix=IoA_matrix,
- aggr_track=self.aggr_track,
- IoA_thresh_aggr=self.IoA_thresh_aggressive,
- IDs_curr_untracked=self.IDs_curr_untracked,
- IDs_prev=self.IDs_prev,
- return_all=True,
- mother_daughters=self.mother_daughters,
- unique_ID=unique_ID
- )
+ prev_rp = acdcRegionprops(prev_lab.copy(), precache_centroids=False)
+
+ full_IoA_matrix, full_curr_IDs, self.IDs_prev = calc_Io_matrix(
+ lab,
+ prev_lab,
+ self.rp,
+ prev_rp,
+ )
+ IoA_matrix, self.IDs_curr_untracked, _ = calc_Io_matrix(
+ lab,
+ prev_lab,
+ self.rp,
+ prev_rp,
+ specific_IDs=specific_IDs,
+ )
+ full_aggr_track, full_mother_daughters = mother_daughter_assign(
+ full_IoA_matrix,
+ IoA_thresh_daughter=self.IoA_thresh_daughter,
+ min_daughter=self.min_daughter,
+ max_daughter=self.max_daughter,
+ IoA_thresh_instant=self.IoA_thresh,
+ )
+
+ subset_idx_mapper = {
+ curr_ID: idx for idx, curr_ID in enumerate(self.IDs_curr_untracked)
+ }
+ self.aggr_track = [
+ subset_idx_mapper[full_curr_IDs[idx]]
+ for idx in full_aggr_track
+ if full_curr_IDs[idx] in subset_idx_mapper
+ ]
+ self.mother_daughters = []
+ for mother_idx, daughter_idxs in full_mother_daughters:
+ subset_daughter_idxs = [
+ subset_idx_mapper[full_curr_IDs[idx]]
+ for idx in daughter_idxs
+ if full_curr_IDs[idx] in subset_idx_mapper
+ ]
+ if subset_daughter_idxs:
+ self.mother_daughters.append((mother_idx, subset_daughter_idxs))
-
- self.tracked_video[frame_i] = self.tracked_lab
+ out = track_frame_base(
+ prev_lab,
+ prev_rp,
+ lab,
+ self.rp,
+ IoA_thresh=self.IoA_thresh,
+ IoA_matrix=IoA_matrix,
+ aggr_track=self.aggr_track,
+ IoA_thresh_aggr=self.IoA_thresh_aggressive,
+ IDs_curr_untracked=self.IDs_curr_untracked,
+ IDs_prev=self.IDs_prev,
+ return_all=True,
+ mother_daughters=self.mother_daughters,
+ unique_ID=unique_ID,
+ specific_IDs=specific_IDs,
+ return_assignments=return_assignments,
+ dont_return_tracked_lab=dont_return_tracked_lab,
+ )
+ if dont_return_tracked_lab:
+ IoA_matrix, self.assignments, self.tracked_IDs = out
+ else:
+ self.tracked_lab, IoA_matrix, self.assignments, self.tracked_IDs = out
+ self.tracked_video[frame_i] = self.tracked_lab
class normal_division_lineage_tree:
"""
@@ -594,8 +536,8 @@ def init_lineage_tree(self, lab=None, first_df=None, frame_i=None):
if lab is not None:
- rp = regionprops(lab)
- labels = [obj.label for obj in rp]
+ rp = acdcRegionprops(lab, precache_centroids=False)
+ labels = rp.IDs
cca_df = pd.DataFrame({
'Cell_ID': labels,
})
@@ -730,10 +672,10 @@ def real_time(self, frame_i, lab, prev_lab, rp=None, prev_rp=None):
None
"""
if rp is None:
- rp = regionprops(lab)
+ rp = acdcRegionprops(lab, precache_centroids=False)
if prev_rp is None:
- prev_rp = regionprops(prev_lab)
+ prev_rp = acdcRegionprops(prev_lab, precache_centroids=False)
IoA_matrix, self.IDs_curr_untracked, self.IDs_prev = calc_Io_matrix(lab, prev_lab, rp, prev_rp)
@@ -751,7 +693,7 @@ def real_time(self, frame_i, lab, prev_lab, rp=None, prev_rp=None):
self.mother_daughters = filtered_mother_daughters
curr_IDs = set(self.IDs_curr_untracked)
- prev_IDs = {obj.label for obj in prev_rp}
+ prev_IDs = set(prev_rp.IDs)
new_IDs = curr_IDs - prev_IDs
self.frames_for_dfs.add(frame_i)
self.add_new_frame(frame_i, self.mother_daughters, self.IDs_prev, self.IDs_curr_untracked, None, curr_IDs, new_IDs)
@@ -842,80 +784,6 @@ def update_df_li_locally(self, df, frame_i):
df_data.loc[ID] = cell_row
- # This will probably be made obsolete by the gui_mode version
- # def insert_lineage_df(self, lineage_df, frame_i, update_fams=True,
- # consider_children=True, raw_input=False, propagate=True,
- # relevant_cells=None):
- # """
- # Insert or replace a lineage DataFrame at a given frame index, optionally updating families and propagating changes.
-
- # Args:
- # lineage_df (pd.DataFrame): The lineage DataFrame to insert.
- # frame_i (int): The index of the frame.
- # update_fams (bool, optional): If True, update families based on the changes. Defaults to True.
- # consider_children (bool, optional): If True, update children of the inserted frame. Defaults to True.
-
- # Returns:
- # None
- # """
- # if not self.gui_mode:
- # printl("here")
- # if not raw_input:
- # lineage_df = reorg_sister_cells_for_import(lineage_df)
- # lineage_df = filter_cols(lineage_df)
-
- # lineage_df = checked_reset_index_Cell_ID(lineage_df)
- # len_lineage_list = len(self.lineage_list)
- # if frame_i == len_lineage_list:
- # self.lineage_list.append(lineage_df)
- # self.frames_for_dfs.add(frame_i)
-
- # self.update_df_li_locally(lineage_df, frame_i)
-
- # if propagate:
- # out = update_consistency(df_li=self.lineage_list, fixed_frame_i=frame_i,
- # consider_children=consider_children, Cell_IDs_fixed=relevant_cells,
- # families=self.families if update_fams else None)
- # if update_fams:
- # self.lineage_list, self.families = out
- # else:
- # self.lineage_list = out
-
- # elif frame_i < len_lineage_list:
- # self.lineage_list[frame_i] = lineage_df
- # self.update_df_li_locally(lineage_df, frame_i)
- # if propagate:
- # out = update_consistency(df_li=self.lineage_list, fixed_frame_i=frame_i,
- # consider_children=consider_children, Cell_IDs_fixed=relevant_cells,
- # families=self.families if update_fams else None)
- # if update_fams:
- # self.lineage_list, self.families = out
- # else:
- # self.lineage_list = out
-
-
- # elif frame_i > len_lineage_list:
- # printl(f'WARNING: Frame_i {frame_i} was inserted. The lineage list was only {len(self.lineage_list)} frames long, so the last known lineage tree was copy pasted up to frame_i {frame_i}')
-
- # original_length = len(self.lineage_list)
- # self.lineage_list = self.lineage_list + [self.lineage_list[-1]] * (frame_i - len(self.lineage_list))
-
- # self.generate_gen_df_from_df_li(self.lineage_list, force=True)
-
- # self.lineage_list.append(lineage_df)
-
- # frame_is = set(range(len(self.lineage_list)-original_length))
- # self.frames_for_dfs = self.frames_for_dfs | frame_is
-
- # self.update_df_li_locally(lineage_df, frame_i)
- # if propagate:
- # out = update_consistency(df_li=self.lineage_list, fixed_frame_i=frame_i,
- # consider_children=consider_children, Cell_IDs_fixed=relevant_cells,
- # families=self.families if update_fams else None)
- # if update_fams:
- # self.lineage_list, self.families = out
- # else:
- # self.lineage_list = out
def _update_consistency(self, fixed_frame_i=None, fixed_df=None,
Cell_IDs_fixed=None, consider_children=True):
@@ -1043,63 +911,6 @@ def propagate(self, frame_i, relevant_cells=None):
self._update_consistency(fixed_frame_i=frame_i,
consider_children=True, Cell_IDs_fixed=relevant_cells)
- # This will probably be made obsolete by the gui_mode version
- # def load_lineage_df_list(self, df_li):
- # """
- # Load a list of lineage DataFrames, reconstructing the lineage tree and families.
-
- # Args:
- # df_li (list): List of acdc_df DataFrames.
-
- # Returns:
- # None
- # """
- # df_li = copy.deepcopy(df_li) # Ensure we don't modify the original list
- # # Support for first_frame was removed since it is not necessary, just make the df_li correct...
- # # Also the tree needs to be init before. Also if df_li does not contain any relevant dfs, nothing happens
- # print('Loading lineage data...')
- # df_li_new = []
- # families = []
- # families_root_IDs = []
- # added_IDs = set()
-
- # for i, df in enumerate(df_li):
- # if df is None:
- # continue
-
- # if 'generation_num_tree' not in df.columns:
- # continue
-
- # mask = (df['generation_num_tree'].isnull() |
- # df["generation_num_tree"].isna())
-
- # if mask.any() or df["generation_num_tree"].empty:
- # continue
-
- # df = checked_reset_index_Cell_ID(df)
-
- # df = filter_cols(df)
- # df = reorg_sister_cells_for_import(df)
- # self.frames_for_dfs.add(i)
- # df_li_new.append(df)
-
- # df_filter = df.index.isin(added_IDs)
- # for root_ID, group in df[df_filter].groupby('root_ID_tree'):
- # if root_ID not in families_root_IDs:
- # family = list(zip(group.index, group['generation_num_tree']))
- # families.append(family)
- # families_root_IDs.append(root_ID)
- # else:
- # # If the root_ID is already in families, we just update the family with the new cells
- # family_index = families_root_IDs.index(root_ID)
- # families[family_index].extend(zip(group.index, group['generation_num_tree']))
-
- # added_IDs.update(group.index)
-
- # if df_li_new:
- # self.lineage_list = df_li_new
-
- # This will probably be made obsolete by the gui_mode version
def export_df(self, frame_i):
"""
Export the lineage DataFrame for a specific frame, cleaning up auxiliary columns.
@@ -1256,8 +1067,8 @@ def track(self,
IoA_thresh_daughter=IoA_thresh_daughter
)
pbar.update()
- rp = regionprops(segm_video[0])
- prev_IDs = {obj.label for obj in rp}
+ rp = acdcRegionprops(segm_video[0], precache_centroids=False)
+ prev_IDs = rp.IDs_set
prev_rp = rp
continue
@@ -1270,8 +1081,8 @@ def track(self,
IDs_prev = tracker.IDs_prev
assignments = tracker.assignments
IDs_curr_untracked = tracker.IDs_curr_untracked
- rp = regionprops(tracker.tracked_lab)
- curr_IDs = {obj.label for obj in rp}
+ rp = acdcRegionprops(tracker.tracked_lab)
+ curr_IDs = rp.IDs_set
new_IDs = curr_IDs - prev_IDs
if record_lineage or return_tracked_lost_centroids:
tree.add_new_frame(
@@ -1289,7 +1100,7 @@ def track(self,
found = True
break
if not found:
- labels = [obj.label for obj in rp]
+ labels = rp.IDs
printl(mother, mother_ID, IDs_curr_untracked, labels)
raise ValueError('Something went wrong with the tracked lost centroids.')
@@ -1328,6 +1139,9 @@ def track_frame(self,
min_daughter:int = 2,
max_daughter:int = 2,
unique_ID: NotGUIParam =None,
+ return_assignments: NotGUIParam =False,
+ specific_IDs: NotGUIParam =None,
+ dont_return_tracked_lab: NotGUIParam =False,
):
"""
Tracks cell division in a single frame. (This is used for real time tracking in the GUI)
@@ -1352,14 +1166,32 @@ def track_frame(self,
segm_video = [previous_frame_labels, current_frame_labels]
tracker = normal_division_tracker(segm_video, IoA_thresh_daughter, min_daughter, max_daughter, IoA_thresh, IoA_thresh_aggressive)
- tracker.track_frame(1, IDs=IDs, unique_ID=unique_ID)
- tracked_video = tracker.tracked_video
+ tracker.track_frame(
+ 1,
+ IDs=IDs,
+ unique_ID=unique_ID,
+ return_assignments=return_assignments,
+ specific_IDs=specific_IDs,
+ dont_return_tracked_lab=dont_return_tracked_lab,
+ )
mother_daughters_pairs = tracker.mother_daughters
IDs_prev = tracker.IDs_prev
mothers = {IDs_prev[pair[0]] for pair in mother_daughters_pairs}
+ assignments = tracker.assignments
+
+ if dont_return_tracked_lab:
+ return assignments
+
+ tracked_lab = tracker.tracked_video[-1]
+ if not return_assignments:
+ return tracked_lab
- return tracked_video[-1], mothers
+ add_info = {
+ 'mothers': mothers,
+ 'assignments': assignments
+ }
+ return tracked_lab, add_info
def updateGuiProgressBar(self, signals):
"""
diff --git a/cellacdc/whitelist.py b/cellacdc/whitelist.py
index 0621f4514..c2211b96e 100644
--- a/cellacdc/whitelist.py
+++ b/cellacdc/whitelist.py
@@ -1,7 +1,7 @@
import os
import numpy as np
import skimage.measure
-from . import printl, myutils
+from . import printl, myutils, regionprops
import json
from typing import Set, List, Tuple
import time
@@ -222,14 +222,14 @@ def create_new_centroids(self,
new_IDs = self.originalLabsIDs[i] - self.originalLabsIDs[i-1]
- rp = None
if frame_i==i and curr_rp is not None:
rp = curr_rp
else:
- rp = skimage.measure.regionprops(self.originalLabs[i])
+ rp = regionprops.acdcRegionprops(self.originalLabs[i],
+ precache_centroids=False)
self.new_centroids.append({
- tuple(map(int, obj.centroid)) for obj in rp if obj.label in new_IDs
+ tuple(map(int, rp.get_centroid(label))) for label in new_IDs
})
@@ -411,7 +411,7 @@ def IDsAccepted(self,
printl('Using curr_lab')
IDs_curr = {obj.label for obj in skimage.measure.regionprops(lab)}
else:
- IDs_curr = allData_li[frame_i]['IDs']
+ IDs_curr = allData_li[frame_i]['regionprops'].IDs_set
if self._debug:
printl('Using allData_li')
@@ -488,7 +488,7 @@ def makeOriginalLabsAndIDs(self, segm_data: np.ndarray,
IDs = set(IDs_curr)
elif allData_li is not None:
try:
- IDs = set(allData_li[i]['IDs'])
+ IDs = allData_li[i]['regionprops'].IDs_set
except KeyError:
pass
if IDs is None:
@@ -746,7 +746,7 @@ def propagateIDs(self,
printl('Using index_lab_combo')
IDs_curr = {obj.label for obj in skimage.measure.regionprops(lab)}
elif curr_rp is not None:
- IDs_curr = {obj.label for obj in curr_rp}
+ IDs_curr = curr_rp.IDs_set
if self._debug:
printl('Using rp')
elif curr_lab is not None:
@@ -755,7 +755,7 @@ def propagateIDs(self,
printl('Using curr_lab')
IDs_curr = {obj.label for obj in skimage.measure.regionprops(lab)}
else:
- IDs_curr = allData_li[frame_i]['IDs']
+ IDs_curr = allData_li[frame_i]['regionprops'].IDs_set
if self._debug:
printl('Using allData_li')
@@ -872,7 +872,7 @@ def propagateIDs(self,
if frame_i == i:
IDs_curr_loc = IDs_curr
else:
- IDs_curr_loc = set(allData_li[i]['IDs'])
+ IDs_curr_loc =allData_li[i]['regionprops'].IDs_set
new_whitelist = self.get(i, try_create_new_whitelists).copy()
old_whitelist = new_whitelist.copy()
@@ -939,12 +939,12 @@ def whitelistTrackOGagainstPreviousFrame_cb(self, signal_slot=None):
if not self.whitelistCheckOriginalLabels():
return
old_cell_IDs = posData.whitelist.originalLabsIDs[frame_i]
- prev_cell_IDs = posData.allData_li[frame_i-1]['IDs']
+ prev_cell_IDs = posData.allData_li[frame_i-1]['regionprops'].IDs_set
self.whitelistTrackOGCurr(against_prev=True)
new_cell_IDs = posData.whitelist.originalLabsIDs[frame_i]
new_IDs = new_cell_IDs - old_cell_IDs
- new_IDs = new_IDs & set(prev_cell_IDs)
+ new_IDs = new_IDs & prev_cell_IDs
self.whitelistUpdateLab(
track_og_curr=False, IDs_to_add=new_IDs,
@@ -1066,7 +1066,7 @@ def whitelistViewOGIDs(self, checked:bool):
self.store_data(autosave=False)
if frame_i > 0:
- missing_IDs = set(posData.IDs) - set(posData.allData_li[frame_i-1]['IDs'])
+ missing_IDs = posData.IDs_set - posData.allData_li[frame_i-1]['regionprops'].IDs_set
self.trackManuallyAddedObject(missing_IDs,isNewID=True, wl_update=False)
self.setAllTextAnnotations()
@@ -1502,7 +1502,7 @@ def whitelistTrackOGCurr(self, frame_i:int=None,
### against what should I track?
if lab is not None and not rp:
- rp = skimage.measure.regionprops(lab)
+ rp = regionprops.acdcRegionprops(lab, precache_centroids=False)
changed_frame = False
if lab is None:
@@ -1520,7 +1520,7 @@ def whitelistTrackOGCurr(self, frame_i:int=None,
rp = posData.rp
lab = posData.lab
og_lab = posData.whitelist.originalLabs[frame_i]
- og_rp = skimage.measure.regionprops(og_lab)
+ og_rp = regionprops.acdcRegionprops(og_lab, precache_centroids=False)
# lab = lab.copy()
denom_overlap_matrix = 'union' if not against_prev else 'area_prev'
@@ -1530,7 +1530,6 @@ def whitelistTrackOGCurr(self, frame_i:int=None,
denom_overlap_matrix=denom_overlap_matrix,
posData = posData,
setBrushID_func=self.setBrushID,
- IDs=IDs,
# assign_unique_new_IDs=False,
)
@@ -1583,7 +1582,7 @@ def whitelistTrackCurrOG(self, frame_i:int=None, against_prev:bool=False):
else:
og_lab = posData.whitelist.originalLabs[frame_i]
- og_rp = skimage.measure.regionprops(og_lab)
+ og_rp = regionprops.acdcRegionprops(og_lab, precache_centroids=False)
denom_overlap_matrix = 'union' if not against_prev else 'area_prev'
diff --git a/cellacdc/widgets.py b/cellacdc/widgets.py
index b22e1cec4..bb247e504 100755
--- a/cellacdc/widgets.py
+++ b/cellacdc/widgets.py
@@ -71,6 +71,7 @@
from . import _core, core
from . import QtScoped
from . import prompts
+from . import fonts
from .acdc_regex import float_regex
from .config import PREPROCESS_MAPPER
from . import _base_widgets
@@ -84,8 +85,7 @@
PROGRESSBAR_HIGHLIGHTEDTEXT_QCOLOR = _palettes.QProgressBarHighlightedTextColor()
TEXT_COLOR = _palettes.text_float_rgba()
-font = QFont()
-font.setPixelSize(12)
+font = fonts.font
custom_cmaps_filepath = os.path.join(settings_folderpath, 'custom_colormaps.ini')
@@ -1075,7 +1075,7 @@ class _ReorderableListModel(QAbstractListModel):
def __init__(self, items, parent=None):
QAbstractItemModel.__init__(self, parent)
- self.nodes = items
+ self.nodes = list(items)
self.lastDroppedItems = []
self.pendingRemoveRowsAfterDrop = False
@@ -1286,7 +1286,7 @@ def __init__(
self.setStyleSheet(styleSheet)
def setItems(self, items):
- self._model.nodes = items
+ self._model.nodes = list(items)
def items(self):
return self._model.nodes
@@ -1464,10 +1464,9 @@ def warnSelectionEmpty(self):
def ok_cb(self, checked=False):
self.clickedButton = self.sender()
- self.cancel = False
selectedItems = self.listBox.selectedItems()
- self.selectedItemsText = [item.text() for item in selectedItems]
- if not self.allowSingleSelection and len(self.selectedItemsText) < 2:
+ selectedItemsText = [item.text() for item in selectedItems]
+ if not self.allowSingleSelection and len(selectedItemsText) < 2:
msg = myMessageBox(wrapText=False, showCentered=False)
txt = html_utils.paragraph(
'You need to select two or more items.
'
@@ -1477,9 +1476,12 @@ def ok_cb(self, checked=False):
msg.warning(self, 'Select two or more items', txt)
return
- if not self.allowEmptySelection and not self.selectedItemsText:
+ if not self.allowEmptySelection and not selectedItemsText:
self.warnSelectionEmpty()
return
+
+ self.cancel = False
+ self.selectedItemsText = selectedItemsText
self.sigSelectionConfirmed.emit(self.selectedItemsText)
self.close()
@@ -12083,4 +12085,305 @@ def closeEvent(self, event):
if self.screenShotWin is not None:
self.screenShotWin.close()
- return super().closeEvent(event)
\ No newline at end of file
+ return super().closeEvent(event)
+
+
+class MultiPickListWidget(QWidget):
+ """Generic list widget with multi-pick (repeated-selection) support.
+
+ Each pickable row shows ``- count +`` controls. Left-clicking adds
+ one instance; right-clicking or Ctrl+left-click removes one. The same
+ item can appear multiple times in :attr:`selectionSequence`.
+
+ Parameters
+ ----------
+ items:
+ Initial list of item labels.
+ excludedItems:
+ Labels that are shown in the list but *not* given +/- controls
+ (e.g. placeholder entries like "Add custom model…"). Click events
+ on these are silently ignored.
+ parent:
+ Optional parent widget.
+ """
+
+ sigSelectionChanged = Signal(list) # emits selectionSequence on every change
+
+ def __init__(self, items=None, excludedItems=None, parent=None):
+ super().__init__(parent)
+
+ self._excludedItems = set(excludedItems or [])
+ self._itemsMap = {} # label → QListWidgetItem
+ self._countMap = defaultdict(int)
+ self._countLabelMap = {}
+ self.selectionSequence = []
+
+ layout = QVBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+
+ self.listBox = listWidget(isMultipleSelection=False)
+ self.listBox.setSelectionMode(QAbstractItemView.SelectionMode.SingleSelection)
+
+ for label in (items or []):
+ self._addListItem(label)
+
+ if self._itemsMap:
+ self.listBox.setCurrentRow(0)
+
+ self.listBox.itemClicked.connect(self._onItemClicked)
+ self.listBox.setContextMenuPolicy(Qt.ContextMenuPolicy.CustomContextMenu)
+ self.listBox.customContextMenuRequested.connect(self._onRightClick)
+
+ # self.listBox.setStyleSheet(LISTWIDGET_STYLESHEET)
+ # self.setStyleSheet(LISTWIDGET_STYLESHEET)
+ layout.addWidget(self.listBox)
+
+
+ @property
+ def itemsMap(self):
+ """Dict mapping label → QListWidgetItem for all pickable items."""
+ return dict(self._itemsMap)
+
+ def currentItemName(self):
+ """Return the label of the currently highlighted item, or ``None``."""
+ item = self.listBox.currentItem()
+ return item.text() if item is not None else None
+
+ def addSelection(self, label):
+ """Add one instance of *label* to the selection."""
+ if label not in self._itemsMap:
+ return
+ self.selectionSequence.append(label)
+ self._countMap[label] += 1
+ self._updateCountLabel(label)
+ self.listBox.setCurrentItem(self._itemsMap[label])
+ self.sigSelectionChanged.emit(list(self.selectionSequence))
+
+ def removeSelection(self, label):
+ """Remove the last instance of *label* from the selection."""
+ if self._countMap.get(label, 0) <= 0:
+ return
+ for i in range(len(self.selectionSequence) - 1, -1, -1):
+ if self.selectionSequence[i] == label:
+ self.selectionSequence.pop(i)
+ break
+ self._countMap[label] = max(0, self._countMap[label] - 1)
+ self._updateCountLabel(label)
+ self.sigSelectionChanged.emit(list(self.selectionSequence))
+
+ def resetSelection(self):
+ """Clear all selections and reset all counters to zero."""
+ self.selectionSequence = []
+ self._countMap = defaultdict(int)
+ for label in self._countLabelMap:
+ self._updateCountLabel(label)
+ self.sigSelectionChanged.emit([])
+
+ def setSelectionFromList(self, labels):
+ """Set the selection to *labels* (duplicates supported)."""
+ self.resetSelection()
+ for label in labels:
+ self.addSelection(label)
+
+ def registerItem(self, label, insertBeforeLabel=None):
+ """Dynamically add a new pickable item.
+
+ Parameters
+ ----------
+ label:
+ Text for the new item.
+ insertBeforeLabel:
+ If given, insert the new item immediately before this label.
+ If not found or not given, the item is appended.
+
+ Returns the created ``QListWidgetItem``.
+ """
+ if label in self._itemsMap:
+ return self._itemsMap[label]
+
+ item = QListWidgetItem(label)
+
+ if insertBeforeLabel is not None:
+ target = self._itemsMap.get(insertBeforeLabel)
+ if target is None:
+ for row in range(self.listBox.count()):
+ row_item = self.listBox.item(row)
+ if row_item is not None and row_item.text() == insertBeforeLabel:
+ target = row_item
+ break
+ if target is not None:
+ row = self.listBox.row(target)
+ self.listBox.insertItem(row, item)
+ else:
+ self.listBox.addItem(item)
+ else:
+ self.listBox.addItem(item)
+
+ self._itemsMap[label] = item
+ self._addCounterWidget(label, item)
+ return item
+
+ def _addListItem(self, label):
+ """Create a QListWidgetItem and, if pickable, attach a counter widget."""
+ item = QListWidgetItem(label)
+ self.listBox.addItem(item)
+ if label not in self._excludedItems:
+ self._itemsMap[label] = item
+ self._addCounterWidget(label, item)
+
+ def _addCounterWidget(self, label, item):
+ rowWidget = QWidget()
+ rowLayout = QHBoxLayout(rowWidget)
+ rowLayout.setContentsMargins(4, 0, 4, 0)
+ rowLayout.setSpacing(6)
+
+ nameLabelPlaceholder = QSpacerItem(2, 0)
+ minusBtn = QPushButton('-')
+ plusBtn = QPushButton('+')
+ countLabel = QLabel(str(self._countMap.get(label, 0)))
+
+ minusBtn.setFixedWidth(24)
+ plusBtn.setFixedWidth(24)
+ countLabel.setMinimumWidth(20)
+ countLabel.setAlignment(Qt.AlignCenter)
+
+ minusBtn.clicked.connect(lambda _, lbl=label: self.removeSelection(lbl))
+ plusBtn.clicked.connect(lambda _, lbl=label: self.addSelection(lbl))
+
+ rowLayout.addItem(nameLabelPlaceholder)
+ rowLayout.addStretch(1)
+ rowLayout.addWidget(minusBtn)
+ rowLayout.addWidget(countLabel)
+ rowLayout.addWidget(plusBtn)
+
+ self._countLabelMap[label] = countLabel
+ self._updateCountLabel(label)
+ self.listBox.setItemWidget(item, rowWidget)
+
+ def _updateCountLabel(self, label):
+ lbl = self._countLabelMap.get(label)
+ if lbl is not None:
+ count = self._countMap.get(label, 0)
+ lbl.setText(str(count))
+ if count <= 0:
+ lbl.setStyleSheet('color: gray;')
+ else:
+ lbl.setStyleSheet('')
+
+ def _onItemClicked(self, item):
+ label = item.text()
+ if label in self._excludedItems:
+ return
+ modifiers = QApplication.keyboardModifiers()
+ if modifiers & Qt.ControlModifier:
+ self.removeSelection(label)
+ else:
+ self.addSelection(label)
+
+ def _onRightClick(self, pos):
+ item = self.listBox.itemAt(pos)
+ if item is None:
+ return
+ label = item.text()
+ if label in self._excludedItems:
+ return
+ self.removeSelection(label)
+
+
+class ModelSelectionWidget(QWidget):
+ """List widget for selecting segmentation models.
+
+ Thin wrapper around :class:`MultiPickListWidget` that populates the list
+ with the installed models and adds a special "Add custom model…" entry.
+
+ ``sigSelectionChanged`` and ``selectionSequence`` are proxied from the
+ underlying :class:`MultiPickListWidget`.
+ """
+
+ _ADD_CUSTOM = 'Add custom model...'
+
+ sigSelectionChanged = Signal(list)
+
+ def __init__(self, parent=None, customFirst='', allowMultiSelection=False):
+ super().__init__(parent)
+
+ self.allowMultiSelection = allowMultiSelection
+
+ models = myutils.get_list_of_models()
+ if customFirst:
+ try:
+ models.insert(0, models.pop(models.index(customFirst)))
+ except ValueError:
+ pass
+
+ items = models + [self._ADD_CUSTOM]
+
+ layout = QVBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+
+ if allowMultiSelection:
+ items = models
+ self._picker = MultiPickListWidget(
+ items=items,
+ excludedItems=[self._ADD_CUSTOM],
+ parent=self,
+ )
+ self._picker.listBox.setFont(font)
+ self._picker.sigSelectionChanged.connect(self.sigSelectionChanged)
+ self.listBox = self._picker.listBox
+ layout.addWidget(self._picker)
+ else:
+ self.listBox = listWidget(isMultipleSelection=False)
+ self.listBox.setFont(font)
+ self.listBox.addItems(models)
+ add_item = QListWidgetItem(self._ADD_CUSTOM)
+ add_item.setFont(fonts.italicFont)
+ self.listBox.addItem(add_item)
+ self.listBox.setSelectionMode(QAbstractItemView.SelectionMode.SingleSelection)
+ self.listBox.setCurrentRow(0)
+ self._picker = None
+ layout.addWidget(self.listBox)
+
+ # ------------------------------------------------------------------
+ # Proxy helpers (multi-selection mode only)
+ # ------------------------------------------------------------------
+
+ @property
+ def selectionSequence(self):
+ return self._picker.selectionSequence if self._picker is not None else []
+
+ @property
+ def modelItemsMap(self):
+ return self._picker.itemsMap if self._picker is not None else {}
+
+ def currentModelName(self):
+ if self._picker is not None:
+ return self._picker.currentItemName()
+ item = self.listBox.currentItem()
+ return item.text() if item is not None else None
+
+ def addModelSelection(self, name):
+ if self._picker is not None:
+ self._picker.addSelection(name)
+
+ def removeModelSelection(self, name):
+ if self._picker is not None:
+ self._picker.removeSelection(name)
+
+ def resetSelectionSequence(self):
+ if self._picker is not None:
+ self._picker.resetSelection()
+
+ def setSelectionFromList(self, models):
+ if self._picker is not None:
+ self._picker.setSelectionFromList(models)
+
+ def registerCustomModel(self, model_name):
+ """Add a newly registered custom model and return its item."""
+ if self._picker is not None:
+ return self._picker.registerItem(
+ model_name, insertBeforeLabel=self._ADD_CUSTOM
+ )
+ item = QListWidgetItem(model_name)
+ self.listBox.insertItem(self.listBox.count() - 1, item)
+ return item
diff --git a/cellacdc/workers.py b/cellacdc/workers.py
index 25feb6f72..cb21ae732 100755
--- a/cellacdc/workers.py
+++ b/cellacdc/workers.py
@@ -43,7 +43,7 @@
from . import cli
from .utils import resize
from . import segm_utils
-
+from . import regionprops
DEBUG = False
def worker_exception_handler(func):
@@ -178,7 +178,7 @@ def run(self):
for frame_i, data_dict in enumerate(self.posData.allData_li):
lab = data_dict['labels']
rp = data_dict['regionprops']
- IDs = data_dict['IDs']
+ IDs = data_dict['regionprops'].IDs
if lab is None:
lab = self.posData.segm_data[frame_i]
rp = skimage.measure.regionprops(lab)
@@ -198,14 +198,10 @@ def run(self):
class SegForLostIDsWorker(QObject):
sigAskInit = Signal()
- sigAskInstallModel = Signal(str)
sigshowImageDebug = Signal(object)
sigStoreData = Signal(bool)
sigUpdateRP = Signal(bool, bool)
- # sigGetData = Signal()
- # sigGet2Dlab = Signal()
- # sigGetTrackedLostIDs = Signal()
- # sigGetBrushID = Signal()
+ sigGetSegForLostIDsInputImg = Signal(str)
sigSegForLostIDsWorkerAskInstallGPU = Signal(str, bool)
sigTrackManuallyAddedObject = Signal(object, object, bool, bool)
@@ -217,6 +213,7 @@ def __init__(self, guiWin, mutex, waitCond, debug=False):
self.mutex = mutex
self.waitCond = waitCond
self._debug = debug
+ self.inputImgForSegForLostIDs = None
def emitSigAskInit(self):
self.mutex.lock()
@@ -224,35 +221,26 @@ def emitSigAskInit(self):
self.waitCond.wait(self.mutex)
self.mutex.unlock()
- def emitSigShowImageDebug(self, img):
- # self.mutex.lock()
- self.sigshowImageDebug.emit(img)
- # self.waitCond.wait(self.mutex)
- # self.mutex.unlock()
-
def emitSigStoreData(self, autosave):
self.mutex.lock()
self.sigStoreData.emit(autosave)
self.waitCond.wait(self.mutex)
self.mutex.unlock()
- def emitSigUpdateRP(self, wl_track_og_curr, wl_update):
+ def emitSigUpdateRP(self, wl_update, wl_track_og_curr):
self.mutex.lock()
- self.sigUpdateRP.emit(wl_track_og_curr, wl_update)
+ self.sigUpdateRP.emit(wl_update, wl_track_og_curr)
self.waitCond.wait(self.mutex)
self.mutex.unlock()
- # def emitSigGetData(self):
- # self.mutex.lock()
- # self.sigGetData.emit()
- # self.waitCond.wait(self.mutex)
- # self.mutex.unlock()
-
- def emitSigAskInstallModel(self, model_name):
+ def emitGetSegForLostIDsInputImg(self, image_channel_name):
self.mutex.lock()
- self.sigAskInstallModel.emit(model_name)
+ self.sigGetSegForLostIDsInputImg.emit(image_channel_name)
self.waitCond.wait(self.mutex)
+ img = self.inputImgForSegForLostIDs
+ self.inputImgForSegForLostIDs = None
self.mutex.unlock()
+ return img
def emitSigAskInstallGPU(self, base_model_name, use_gpu):
self.mutex.lock()
@@ -261,24 +249,6 @@ def emitSigAskInstallGPU(self, base_model_name, use_gpu):
self.waitCond.wait(self.mutex)
self.mutex.unlock()
- # def emitGet2Dlab(self):
- # self.mutex.lock()
- # self.sigGet2Dlab.emit()
- # self.waitCond.wait(self.mutex)
- # self.mutex.unlock()
-
- # def emitGetTrackedLostIDs(self):
- # self.mutex.lock()
- # self.sigGetTrackedLostIDs.emit()
- # self.waitCond.wait(self.mutex)
- # self.mutex.unlock()
-
- # def emitGetBrushID(self):
- # self.mutex.lock()
- # self.sigGetBrushID.emit()
- # self.waitCond.wait(self.mutex)
- # self.mutex.unlock()
-
def emitTrackManuallyAddedObject(self, IDs, isLost, wl_update, wl_track_og_curr):
self.mutex.lock()
self.sigTrackManuallyAddedObject.emit(IDs, isLost, wl_update, wl_track_og_curr)
@@ -298,36 +268,69 @@ def run(self):
return
self.logger.info('Segmentation for lost IDs started.')
- model_name = 'local_seg'
- base_model_name = self.guiWin.SegForLostIDsSettings['base_model_name']
- idx = self.guiWin.modelNames.index(model_name)
- acdcSegment = self.guiWin.acdcSegment_li[idx]
-
- init_kwargs = self.guiWin.SegForLostIDsSettings['win'].init_kwargs
-
- use_gpu = init_kwargs.get('device_type', 'cpu') != 'cpu'
- use_gpu = use_gpu or init_kwargs.get('use_gpu', False)
-
- self.emitSigAskInstallGPU(base_model_name, use_gpu)
+
+ model_settings = self.guiWin.SegForLostIDsSettings['models_settings']
+
+ n_models = len(model_settings)
+ total_steps = 2 * n_models
+ self.signals.initProgressBar.emit(total_steps)
- if not self.gpu_go:
- self.signals.finished.emit(self)
- return
+ assigned_IDs = []
+ missing_IDs_global = set()
+ original_lab = posData.lab.copy()
+ IDs_bboxs_list = []
+ bboxs_list = []
+ new_labs = []
- if not self.dont_force_cpu:
- if 'device' in init_kwargs:
- init_kwargs['device'] = 'cpu'
- if 'use_gpu' in init_kwargs:
- init_kwargs['use_gpu'] = False
+ prev_lab = self.guiWin.get_2Dlab(posData.allData_li[frame_i-1]['labels'])
+ prev_IDs = posData.allData_li[frame_i-1]['regionprops'].IDs_set
+
+ # iteratively go through models, keeping found labs and only feeding remaining lost IDs to the next model
+ for model_idx, model_settings_i in enumerate(model_settings):
+ base_model_name = model_settings_i['base_model_name']
+ init_kwargs_new = dict(model_settings_i['init_kwargs_new'])
+ image_channel_name = init_kwargs_new.pop(
+ 'image_channel_name', 'Displayed image'
+ )
+ args_new = model_settings_i['args_new']
+ init_kwargs = model_settings_i.get('init_kwargs', {})
+ model_kwargs = model_settings_i.get('model_kwargs', {})
+ preproc_recipe = model_settings_i.get('preproc_recipe', None)
+ applyPostProcessing = model_settings_i.get('applyPostProcessing', False)
+ standardPostProcessKwargs = model_settings_i.get('standardPostProcessKwargs', {})
+ customPostProcessFeatures = model_settings_i.get('customPostProcessFeatures', None)
+ customPostProcessGroupedFeatures = model_settings_i.get('customPostProcessGroupedFeatures', None)
+
+ # Fall back to reading from the live win object when available
+ win = model_settings_i.get('win')
+ if win is not None:
+ init_kwargs = win.init_kwargs
+ model_kwargs = win.model_kwargs
+ preproc_recipe = win.preproc_recipe
+ applyPostProcessing = win.applyPostProcessing
+ standardPostProcessKwargs = win.standardPostProcessKwargs
+ customPostProcessFeatures = win.customPostProcessFeatures
+ customPostProcessGroupedFeatures = win.customPostProcessGroupedFeatures
+
+ use_gpu = init_kwargs.get('device_type', 'cpu').lower() != 'cpu'
+ use_gpu = use_gpu or init_kwargs.get('use_gpu', False)
+
+ self.emitSigAskInstallGPU(base_model_name, use_gpu)
+
+ if not self.gpu_go:
+ self.signals.finished.emit(self)
+ return
+
+ if not self.dont_force_cpu:
+ if 'device' in init_kwargs:
+ init_kwargs_new = dict(init_kwargs_new, device='cpu')
+ if 'use_gpu' in init_kwargs:
+ init_kwargs_new = dict(init_kwargs_new, use_gpu=False)
- if acdcSegment is None or base_model_name != self.guiWin.local_seg_base_model_name:
try:
self.logger.info(f'Importing {base_model_name}...')
- self.emitSigAskInstallModel(base_model_name)
acdcSegment = myutils.import_segment_module(base_model_name)
- self.guiWin.acdcSegment_li[idx] = acdcSegment
- self.guiWin.local_seg_base_model_name = base_model_name
- except (IndexError, ImportError, KeyError) as e:
+ except (IndexError, ImportError, KeyError):
self.logger.warning(
f'Cannot import {base_model_name} model. '
'Please install it first.'
@@ -339,134 +342,165 @@ def run(self):
self.signals.finished.emit(self)
return
- win = self.guiWin.SegForLostIDsSettings['win']
- init_kwargs_new = self.guiWin.SegForLostIDsSettings['init_kwargs_new']
- args_new = self.guiWin.SegForLostIDsSettings['args_new']
-
- model = myutils.init_segm_model(acdcSegment, posData, init_kwargs_new)
- if model is None:
- self.logger.info('Segmentation model was not initialized correctly!')
- self.signals.critical.emit(
- (self, 'Segmentation model was not initialized correctly!')
- )
- self.signals.finished.emit(self)
- return
- if self._debug:
- try:
- model.setupLogger(self.guiwin.logger)
- except Exception as e:
- pass
-
- assigned_IDs = []
- missing_IDs_global = set()
- original_lab = posData.lab.copy()
- IDs_bboxs_list = []
- bboxs_list = []
+ model = myutils.init_segm_model(acdcSegment, posData, init_kwargs_new)
+ if model is None:
+ self.logger.info('Segmentation model was not initialized correctly!')
+ self.signals.critical.emit(
+ (self, 'Segmentation model was not initialized correctly!')
+ )
+ self.signals.finished.emit(self)
+ return
+ if self._debug:
+ try:
+ model.setupLogger(self.guiWin.logger)
+ except Exception:
+ pass
- curr_img = self.guiWin.getDisplayedImg1()
- prev_lab = self.guiWin.get_2Dlab(posData.allData_li[frame_i-1]['labels'])
- prev_IDs = set(posData.allData_li[frame_i-1]['IDs'])
+ curr_img = self.emitGetSegForLostIDsInputImg(image_channel_name)
+ if curr_img is None:
+ self.signals.critical.emit(
+ (self, 'Could not get input image for SegForLostIDsWorker')
+ )
+ self.signals.finished.emit(self)
+ return
- # should probably not paly so much with posData.lab, instead handle stuff myself
- self.signals.initProgressBar.emit(2 * args_new['max_iterations'])
- new_labs = np.zeros([args_new['max_iterations'], *posData.lab.shape], dtype=np.uint32)
- for i in range(args_new['max_iterations']):
curr_lab = self.guiWin.get_2Dlab(posData.lab)
tracked_lost_IDs = self.guiWin.getTrackedLostIDs()
new_unique_ID = self.guiWin.setBrushID(useCurrentLab=True, return_val=True)
- missing_IDs = prev_IDs - set(posData.IDs) - set(tracked_lost_IDs)
+ missing_IDs = prev_IDs - posData.rp.IDs_set - set(tracked_lost_IDs)
missing_IDs_global.update(missing_IDs)
assigned_IDs_prev = assigned_IDs.copy()
out = segm_utils.single_cell_seg(
- model, prev_lab, curr_lab, curr_img,
+ model, prev_lab, curr_lab, curr_img,
missing_IDs, new_unique_ID,
- win, posData,
+ posData,
distance_filler_growth=args_new['distance_filler_growth'],
overlap_threshold=args_new['overlap_threshold'],
padding=args_new['padding'],
+ model_kwargs=model_kwargs,
+ preproc_recipe=preproc_recipe,
+ applyPostProcessing=applyPostProcessing,
+ standardPostProcessKwargs=standardPostProcessKwargs,
+ customPostProcessFeatures=customPostProcessFeatures,
+ customPostProcessGroupedFeatures=customPostProcessGroupedFeatures,
+ debug=self._debug
)
- new_lab, assigned_IDs, IDs_bboxs, bboxs = out
-
+ if self._debug:
+ new_lab, assigned_IDs, IDs_bboxs, bboxs, imgs_to_show = out
+ else:
+ new_lab, assigned_IDs, IDs_bboxs, bboxs = out
+
IDs_bboxs_list.append(IDs_bboxs)
bboxs_list.append(bboxs)
posData.lab = new_lab
self.emitSigUpdateRP(wl_update=True, wl_track_og_curr=False)
newly_assigned_IDs = set(assigned_IDs) - set(assigned_IDs_prev)
self.emitTrackManuallyAddedObject(newly_assigned_IDs, True, False, False)
- new_labs[i] = posData.lab.copy()
+ new_labs.append(posData.lab.copy())
self.signals.progressBar.emit(1)
- if self._debug:
- originals = []
- models = []
-
- posData.lab = original_lab.copy()
-
- global_area_mean = np.mean([obj.area for obj in posData.rp])
- for IDs_bboxs, bboxs in zip(IDs_bboxs_list, bboxs_list):
- model_lab = new_labs[i]
if self._debug:
- originals.append(original_lab.copy())
- models.append(posData.lab.copy())
-
- for IDs, bbox in zip(IDs_bboxs, bboxs):
+ print(f'Model {model_idx}:')
+ print('Displaying curr_img and curr_lab:')
+ display_info = {
+ 'title': f'Model {model_idx}, input image and lab',
+ 'images': [curr_img, curr_lab],
+ 'img_titles': ['curr_img', 'curr_lab']
+ }
+ self.sigshowImageDebug.emit(display_info)
+ print('box_curr_img, box_curr_lab, box_curr_lab_other_IDs_grown, box_curr_img (after filling), box_model_lab')
+ for i, imgs in imgs_to_show.items():
+ display_info = {
+ 'title': f'Model {model_idx}, bbox {i}',
+ 'images': imgs,
+ 'img_titles': [
+ 'box_curr_img', 'box_curr_lab', 'box_curr_lab_other_IDs_grown',
+ 'box_curr_img (after filling)', 'box_model_lab'
+ ]
+ }
+ self.sigshowImageDebug.emit(display_info)
+ global_areas = [obj.area for obj in posData.rp]
+ global_area_mean = np.mean(global_areas) if len(global_areas) > 0 else None
+ for i, (IDs_bboxs, bboxs) in enumerate(zip(IDs_bboxs_list, bboxs_list)):
+ args_new = model_settings[i]['args_new']
+ model_lab = new_labs[i]
+
+ for j, (IDs, bbox) in enumerate(zip(IDs_bboxs, bboxs)):
+ box_x_min, box_x_max, box_y_min, box_y_max = bbox
+
+ model_bbox_lab = model_lab[box_x_min:box_x_max, box_y_min:box_y_max]
+ model_bbox_lab = skimage.segmentation.clear_border(model_bbox_lab, buffer_size=1)
+ model_lab_rp = regionprops.acdcRegionprops(model_bbox_lab, precache_centroids=False)
+
+ if self._debug:
+ IDs_filtered_border = [
+ ID for ID in IDs
+ if ID not in model_lab_rp.IDs_set
+ ]
- box_x_min, box_x_max, box_y_min, box_y_max = bbox
original_bbox_lab = original_lab[box_x_min:box_x_max, box_y_min:box_y_max]
original_bbox_lab_cleared_borders = skimage.segmentation.clear_border(original_bbox_lab)
- box_model_lab = model_lab[box_x_min:box_x_max, box_y_min:box_y_max]
+ original_lab_rp = regionprops.acdcRegionprops(original_bbox_lab_cleared_borders, precache_centroids=False)
- # original_bbox_lab[np.isin(original_bbox_lab, IDs)] = 0 should be a given. If not seg for lost IDs this recommended
-
- box_model_lab = skimage.segmentation.clear_border(box_model_lab, buffer_size=1)
-
- rp_model_lab = skimage.measure.regionprops(box_model_lab)
- rp_original_lab = skimage.measure.regionprops(original_bbox_lab)
- rp_original_lab_cleared = skimage.measure.regionprops(original_bbox_lab_cleared_borders)
-
- original_IDs = [obj.label for obj in rp_original_lab]
- areas = [obj.area for obj in rp_original_lab_cleared]
+ areas = [obj.area for obj in original_lab_rp]
if len(areas) > 0:
area_mean = np.mean(areas)
- else:
+ elif global_area_mean is not None:
area_mean = global_area_mean
+ else:
+ model_areas = [obj.area for obj in model_lab_rp]
+ area_mean = np.mean(model_areas) if len(model_areas) > 0 else None
+
+ skip_size_filter = area_mean is None
+ if not skip_size_filter:
+ min_area = (1 - args_new['size_perc_diff']) * area_mean
+ max_area = (1 + args_new['size_perc_diff']) * area_mean
+
if args_new['allow_only_tracked_cells']:
- filtered_IDs = [obj.label for obj in rp_model_lab
- if obj.area > (1 - args_new['size_perc_diff']) * area_mean
- and obj.area < (1 + args_new['size_perc_diff']) * area_mean
- and obj.label not in original_IDs
- and obj.label in missing_IDs_global]
+ filtered_IDs = [
+ obj.label for obj in model_lab_rp
+ if (skip_size_filter or (obj.area > min_area and obj.area < max_area))
+ and obj.label in prev_IDs # only keep objects that have ID already in the previous frame
+ ]
+
else:
- filtered_IDs = [obj.label for obj in rp_model_lab
- if obj.area > (1 - args_new['size_perc_diff']) * area_mean
- and obj.area < (1 + args_new['size_perc_diff']) * area_mean
- and obj.label not in original_IDs]
-
- if self._debug or DEBUG:
- filtered_sizes = [(obj.label, obj.area) for obj in rp_model_lab if obj.label in filtered_IDs]
- self.logger.info(f"Filtered sizes: {filtered_sizes}")
- for label in filtered_IDs:
- original_bbox_lab[box_model_lab == label] = label # here the stuff should be tracked, so we keep the ID!
-
- # original_lab[box_x_min:box_x_max, box_y_min:box_y_max] = original_bbox_lab
-
- self.signals.progressBar.emit(1)
-
- posData.lab = original_lab
+ filtered_IDs = [
+ obj.label for obj in model_lab_rp
+ if (skip_size_filter or (obj.area > min_area and obj.area < max_area))
+ ]
+
+ if self._debug:
+ if args_new['allow_only_tracked_cells']:
+ IDs_filtered_for_size = [
+ obj.label for obj in model_lab_rp
+ if (skip_size_filter or (obj.area > min_area and obj.area < max_area))
+ ]
+ else:
+ IDs_filtered_for_size = []
+ IDs_filtered_for_tracking = [
+ obj.label for obj in model_lab_rp
+ if obj.label in prev_IDs
+ ]
+ print(f'Model {i}, bbox {j}:')
+ print(f' Start: {[obj.label for obj in model_lab_rp]}')
+ print(f' Size: {IDs_filtered_for_size}')
+ print(f' Tracking: {IDs_filtered_for_tracking}')
+ print(f' Border: {IDs_filtered_border}')
+
+ original_bbox_lab[
+ np.isin(model_bbox_lab, filtered_IDs)
+ ] = model_bbox_lab[np.isin(model_bbox_lab, filtered_IDs)]
- # if self._debug:
- # originals = np.concatenate(originals, axis=0)
- # models = np.concatenate(models, axis=0)
- # self.emitSigShowImageDebug(originals)
- # self.emitSigShowImageDebug(models)
+ self.signals.progressBar.emit(1)
- self.emitSigUpdateRP(wl_track_og_curr=True, wl_update=True)
+ posData.lab = original_lab
+ self.emitSigUpdateRP(wl_update=True, wl_track_og_curr=False)
self.emitSigStoreData(autosave=True)
self.logger.info('Segmentation for lost IDs done.')
-
+
self.signals.finished.emit(self)
class AlignDataWorker(QObject):
@@ -5185,7 +5219,7 @@ def check(self, posData):
# There are no annotations at frame_i --> stop
break
- IDs = data_dict['IDs']
+ IDs = data_dict['regionprops'].IDs
checker = core.CcaIntegrityChecker(cca_df, lab, IDs)
for checkpoint in checkpoints:
@@ -6220,7 +6254,7 @@ def saveAcdcDf(self, posData: load.loadData, end_i):
last_cca_frame_i=self.mainWin.save_cca_until_frame_i
)
- def saveSegmData(self, posData, end_i, saved_segm_data):
+ def saveSegmData(self, posData: load.loadData, end_i, saved_segm_data):
self.progress.emit(f'Saving segmentation data for {posData.relPath}...')
@@ -6263,6 +6297,13 @@ def saveSegmData(self, posData, end_i, saved_segm_data):
io.savez_compressed(
posData.segm_npz_path, np.squeeze(saved_segm_data)
)
+
+ # save information about the segmention
+ posData.updateSegmMetadata(all=True)
+ posData.saveSegmMetadataIni()
+
+ # save rp info about segm
+ self.progress.emit(f'Saving additional data for {posData.relPath}...')
posData.segm_data = saved_segm_data
# Allow single 2D/3D image
if posData.SizeT == 1:
diff --git a/precompile_functions.py b/precompile_functions.py
new file mode 100644
index 000000000..ab9abd49b
--- /dev/null
+++ b/precompile_functions.py
@@ -0,0 +1,28 @@
+# only needed for cython extensions, not needed to run normally
+import sys
+from setuptools import setup, Extension
+from Cython.Build import cythonize
+import numpy as np
+
+setup(
+ ext_modules=cythonize(
+ Extension(
+ "cellacdc.precompiled.precompiled_functions",
+ sources=["cellacdc/precompiled_functions.pyx"],
+ include_dirs=[np.get_include()],
+ ),
+ annotate=True,
+ build_dir="build/cython", # .c and .html files go here
+ )
+)
+# # move compiled binary to precompiled/
+# import shutil
+# import os
+
+# src_dir = "cellacdc"
+# for filename in os.listdir(src_dir):
+# if filename.startswith("precompiled_functions") and (filename.endswith(".so") or filename.endswith(".pyd")):
+# target_path = os.path.join("cellacdc", "precompiled", filename)
+# shutil.move(os.path.join(src_dir, filename), target_path)
+# print(f"Moved {filename} to {target_path}")
+
diff --git a/pyproject.toml b/pyproject.toml
index e96d7a741..14799461b 100755
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -2,7 +2,9 @@
requires = [
"setuptools>=64",
"wheel",
- "setuptools_scm[toml]>=8"
+ "setuptools_scm[toml]>=8",
+ "cython",
+ "numpy",
]
build-backend = "setuptools.build_meta"
diff --git a/tests/test_cellacdc_tracker_specific_ids.py b/tests/test_cellacdc_tracker_specific_ids.py
new file mode 100644
index 000000000..130e1f8bb
--- /dev/null
+++ b/tests/test_cellacdc_tracker_specific_ids.py
@@ -0,0 +1,160 @@
+import numpy as np
+from skimage.measure import regionprops
+
+from cellacdc.trackers.CellACDC import CellACDC_tracker
+from cellacdc.trackers.CellACDC_2steps.CellACDC_2steps_tracker import tracker as TwoStepsTracker
+from cellacdc.trackers.CellACDC_normal_division.CellACDC_normal_division_tracker import tracker as NormalDivisionTracker
+
+
+def test_track_frame_specific_ids_only_tracks_requested_current_ids():
+ prev_lab = np.array(
+ [
+ [1, 1, 0, 5, 5],
+ [1, 1, 0, 5, 5],
+ ],
+ dtype=np.uint16,
+ )
+ lab = np.array(
+ [
+ [7, 7, 0, 5, 5],
+ [7, 7, 0, 5, 5],
+ ],
+ dtype=np.uint16,
+ )
+
+ tracked_lab, assignments = CellACDC_tracker.track_frame(
+ prev_lab,
+ regionprops(prev_lab),
+ lab,
+ regionprops(lab),
+ IDs_curr_untracked=[7, 5],
+ unique_ID=10,
+ assign_unique_new_IDs=True,
+ return_assignments=True,
+ specific_IDs=[5],
+ )
+
+ np.testing.assert_array_equal(tracked_lab, lab)
+ assert assignments['assignments'] == {}
+
+
+def test_track_frame_specific_ids_skips_merging_with_unrelated_current_labels():
+ prev_lab = np.array(
+ [
+ [5, 5, 0, 0],
+ [5, 5, 0, 0],
+ ],
+ dtype=np.uint16,
+ )
+ lab = np.array(
+ [
+ [7, 7, 0, 5],
+ [7, 7, 0, 5],
+ ],
+ dtype=np.uint16,
+ )
+
+ tracked_lab, add_info = CellACDC_tracker.track_frame(
+ prev_lab,
+ regionprops(prev_lab),
+ lab,
+ regionprops(lab),
+ IDs_curr_untracked=[7, 5],
+ unique_ID=10,
+ assign_unique_new_IDs=True,
+ return_assignments=True,
+ specific_IDs=[7],
+ )
+
+ expected = np.array(
+ [
+ [10, 10, 0, 5],
+ [10, 10, 0, 5],
+ ],
+ dtype=np.uint16,
+ )
+
+ np.testing.assert_array_equal(tracked_lab, expected)
+ assert add_info['assignments'] == {7: 10}
+
+
+def test_two_steps_specific_ids_can_match_selected_new_object_to_lost_previous_id():
+ prev_lab = np.array(
+ [
+ [5, 5, 0, 0],
+ [5, 5, 0, 0],
+ ],
+ dtype=np.uint16,
+ )
+ lab = np.array(
+ [
+ [7, 7, 0, 0],
+ [7, 7, 0, 0],
+ ],
+ dtype=np.uint16,
+ )
+
+ tracked_lab, add_info = TwoStepsTracker(
+ annotate_objects_tracked_second_step=False
+ ).track_frame(
+ prev_lab,
+ lab,
+ overlap_threshold=0.4,
+ lost_IDs_search_range=10,
+ unique_ID=10,
+ return_assignments=True,
+ specific_IDs=[7],
+ )
+
+ expected = np.array(
+ [
+ [5, 5, 0, 0],
+ [5, 5, 0, 0],
+ ],
+ dtype=np.uint16,
+ )
+
+ np.testing.assert_array_equal(tracked_lab, expected)
+ assert add_info['assignments'] == {7: 5}
+
+
+def test_normal_division_specific_ids_preserve_division_context():
+ prev_lab = np.array(
+ [
+ [5, 5, 5, 5],
+ [5, 5, 5, 5],
+ ],
+ dtype=np.uint16,
+ )
+ lab = np.array(
+ [
+ [7, 7, 8, 8],
+ [7, 7, 8, 8],
+ ],
+ dtype=np.uint16,
+ )
+
+ tracked_lab, add_info = NormalDivisionTracker().track_frame(
+ prev_lab,
+ lab,
+ IoA_thresh=0.8,
+ IoA_thresh_daughter=0.25,
+ IoA_thresh_aggressive=0.5,
+ min_daughter=2,
+ max_daughter=2,
+ unique_ID=20,
+ return_assignments=True,
+ specific_IDs=[7],
+ )
+
+ expected = np.array(
+ [
+ [20, 20, 8, 8],
+ [20, 20, 8, 8],
+ ],
+ dtype=np.uint16,
+ )
+
+ np.testing.assert_array_equal(tracked_lab, expected)
+ assert add_info['mothers'] == {5}
+ assert add_info['assignments'] == {7: 20}
\ No newline at end of file
diff --git a/tests/test_precompiled_functions.py b/tests/test_precompiled_functions.py
new file mode 100644
index 000000000..0515f6ce4
--- /dev/null
+++ b/tests/test_precompiled_functions.py
@@ -0,0 +1,520 @@
+"""Tests for Cython functions in cellacdc/precompiled_functions.pyx.
+
+Each Cython function is validated against a pure-Python / skimage reference
+implementation on realistic synthetic label images built from filled discs
+(2-D) and filled spheres (3-D).
+
+Run with:
+ pytest tests/test_precompiled_functions.py -v
+"""
+
+import pytest
+import numpy as np
+from skimage.draw import disk, ellipsoid
+from skimage.measure import regionprops
+
+# ---------------------------------------------------------------------------
+# Skip the whole module when the Cython extension is not compiled yet
+# ---------------------------------------------------------------------------
+pytest.importorskip(
+ "cellacdc.precompiled.precompiled_functions",
+ reason="Cython extension not compiled; run: python precompile_functions.py build_ext --inplace",
+)
+from cellacdc.precompiled.precompiled_functions import (
+ find_all_objects_2D,
+ find_all_objects_3D,
+ most_common_projection_3D,
+ object_projections_and_size_3D,
+ object_projection_and_size_3D,
+ calc_IoA_matrix_2D,
+ calc_IoA_matrix_3D,
+)
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+def _make_disc_label_image(shape, specs):
+ """Build a 2-D uint32 label image from a list of (label, cy, cx, radius)."""
+ img = np.zeros(shape, dtype=np.uint32)
+ for label, cy, cx, r in specs:
+ rr, cc = disk((cy, cx), r, shape=shape)
+ img[rr, cc] = label
+ return img
+
+
+def _make_sphere_label_image(shape, specs):
+ """Build a 3-D uint32 label image from (label, cz, cy, cx, rz, ry, rx)."""
+ img = np.zeros(shape, dtype=np.uint32)
+ for label, cz, cy, cx, rz, ry, rx in specs:
+ sph = ellipsoid(rz, ry, rx)
+ sz, sy, sx = sph.shape
+ z0 = cz - sz // 2
+ y0 = cy - sy // 2
+ x0 = cx - sx // 2
+ for dz in range(sz):
+ for dy in range(sy):
+ for dx in range(sx):
+ if not sph[dz, dy, dx]:
+ continue
+ zi, yi, xi = z0 + dz, y0 + dy, x0 + dx
+ if 0 <= zi < shape[0] and 0 <= yi < shape[1] and 0 <= xi < shape[2]:
+ img[zi, yi, xi] = label
+ return img
+
+
+def _reference_bboxes_2D(label_img):
+ """skimage regionprops bounding boxes as {label: (r0, r1, c0, c1)}."""
+ result = {}
+ for obj in regionprops(label_img.astype(np.int32)):
+ r0, c0, r1, c1 = obj.bbox
+ result[obj.label] = (r0, r1, c0, c1)
+ return result
+
+
+def _reference_bboxes_3D(label_img):
+ """skimage regionprops bounding boxes as {label: (z0, z1, r0, r1, c0, c1)}."""
+ result = {}
+ for obj in regionprops(label_img.astype(np.int32)):
+ z0, r0, c0, z1, r1, c1 = obj.bbox
+ result[obj.label] = (z0, z1, r0, r1, c0, c1)
+ return result
+
+
+def _python_ioa_matrix(lab, prev_lab, rp, prev_rp, use_union):
+ """Pure-Python IoA matrix (the original fallback in CellACDC_tracker)."""
+ IDs_curr = [obj.label for obj in rp]
+ IDs_prev = [obj.label for obj in prev_rp]
+ IoA = np.zeros((len(IDs_curr), len(IDs_prev)), dtype=np.float64)
+ rp_mapper = {obj.label: obj for obj in rp}
+ idx_curr = {ID: i for i, ID in enumerate(IDs_curr)}
+ for j, obj_prev in enumerate(prev_rp):
+ if use_union:
+ pass # denom computed per overlap
+ else:
+ denom_val = obj_prev.area
+ intersect_IDs, intersects = np.unique(
+ lab[obj_prev.slice][obj_prev.image], return_counts=True
+ )
+ for intersect_ID, I in zip(intersect_IDs, intersects):
+ if intersect_ID == 0 or I == 0:
+ continue
+ if use_union:
+ if intersect_ID not in rp_mapper:
+ continue
+ obj_curr = rp_mapper[intersect_ID]
+ denom_val = obj_prev.area + obj_curr.area - I
+ if denom_val == 0:
+ continue
+ idx = idx_curr.get(intersect_ID)
+ if idx is None:
+ continue
+ IoA[idx, j] = I / denom_val
+ return IoA, IDs_curr, IDs_prev
+
+
+# ---------------------------------------------------------------------------
+# Fixtures
+# ---------------------------------------------------------------------------
+
+DISC_SPECS = [
+ # (label, cy, cx, radius)
+ (1, 30, 30, 20), # large disc, centre-left
+ (2, 30, 90, 15), # medium disc, centre-right
+ (3, 80, 60, 10), # smaller disc, bottom
+ (4, 80, 110, 12), # slightly overlapping with disc 3 at boundary
+]
+
+DISC_SPECS_SHIFTED = [
+ # same cells shifted a few pixels to simulate motion
+ (1, 32, 32, 20),
+ (2, 28, 92, 15),
+ (3, 82, 62, 10),
+ (4, 78, 112, 12),
+]
+
+SPHERE_SPECS = [
+ # (label, cz, cy, cx, rz, ry, rx)
+ (1, 10, 20, 20, 5, 8, 8),
+ (2, 10, 20, 50, 4, 6, 6),
+ (3, 10, 45, 35, 3, 5, 5),
+]
+
+SPHERE_SPECS_SHIFTED = [
+ (1, 11, 21, 21, 5, 8, 8),
+ (2, 9, 19, 52, 4, 6, 6),
+ (3, 10, 46, 34, 3, 5, 5),
+]
+
+
+@pytest.fixture(scope="module")
+def label_2d():
+ return _make_disc_label_image((128, 128), DISC_SPECS)
+
+
+@pytest.fixture(scope="module")
+def label_2d_shifted():
+ return _make_disc_label_image((128, 128), DISC_SPECS_SHIFTED)
+
+
+@pytest.fixture(scope="module")
+def label_3d():
+ return _make_sphere_label_image((20, 64, 64), SPHERE_SPECS)
+
+
+@pytest.fixture(scope="module")
+def label_3d_shifted():
+ return _make_sphere_label_image((20, 64, 64), SPHERE_SPECS_SHIFTED)
+
+
+# ---------------------------------------------------------------------------
+# find_all_objects_2D
+# ---------------------------------------------------------------------------
+
+class TestFindAllObjects2D:
+ def test_labels_match(self, label_2d):
+ labels, _ = find_all_objects_2D(label_2d)
+ assert set(labels.tolist()) == {1, 2, 3, 4}
+
+ def test_bbox_matches_skimage(self, label_2d):
+ labels, bboxes = find_all_objects_2D(label_2d)
+ ref = _reference_bboxes_2D(label_2d)
+ for lbl, bbox in zip(labels.tolist(), bboxes.tolist()):
+ r0, r1, c0, c1 = bbox
+ assert (r0, r1, c0, c1) == ref[lbl], (
+ f"Label {lbl}: got {(r0,r1,c0,c1)}, expected {ref[lbl]}"
+ )
+
+ def test_empty_image_returns_empty(self):
+ empty = np.zeros((64, 64), dtype=np.uint32)
+ result = find_all_objects_2D(empty)
+ assert result == ([], [])
+
+ def test_single_pixel_object(self):
+ img = np.zeros((10, 10), dtype=np.uint32)
+ img[5, 7] = 1
+ labels, bboxes = find_all_objects_2D(img)
+ assert labels[0] == 1
+ assert list(bboxes[0]) == [5, 6, 7, 8]
+
+ def test_label_above_300_triggers_growth(self):
+ """Label > 300 must still produce the correct bounding box."""
+ img = np.zeros((64, 64), dtype=np.uint32)
+ rr, cc = disk((32, 32), 10, shape=img.shape)
+ img[rr, cc] = 350
+ labels, bboxes = find_all_objects_2D(img)
+ ref = _reference_bboxes_2D(img)
+ r0, r1, c0, c1 = bboxes[0].tolist()
+ assert (r0, r1, c0, c1) == ref[350]
+
+ def test_label_above_600_triggers_second_growth(self):
+ img = np.zeros((64, 64), dtype=np.uint32)
+ rr, cc = disk((32, 32), 10, shape=img.shape)
+ img[rr, cc] = 650
+ labels, bboxes = find_all_objects_2D(img)
+ ref = _reference_bboxes_2D(img)
+ r0, r1, c0, c1 = bboxes[0].tolist()
+ assert (r0, r1, c0, c1) == ref[650]
+
+ def test_multiple_labels_across_300_boundary(self):
+ img = np.zeros((64, 64), dtype=np.uint32)
+ rr1, cc1 = disk((20, 20), 8, shape=img.shape)
+ rr2, cc2 = disk((20, 50), 8, shape=img.shape)
+ img[rr1, cc1] = 1
+ img[rr2, cc2] = 301
+ labels, bboxes = find_all_objects_2D(img)
+ ref = _reference_bboxes_2D(img)
+ for lbl, bbox in zip(labels.tolist(), bboxes.tolist()):
+ assert tuple(bbox) == ref[lbl]
+
+ def test_bbox_dtype_is_uint32(self, label_2d):
+ _, bboxes = find_all_objects_2D(label_2d)
+ assert bboxes.dtype == np.uint32
+
+
+# ---------------------------------------------------------------------------
+# find_all_objects_3D
+# ---------------------------------------------------------------------------
+
+class TestFindAllObjects3D:
+ def test_labels_match(self, label_3d):
+ labels, _ = find_all_objects_3D(label_3d)
+ assert set(labels.tolist()) == {1, 2, 3}
+
+ def test_bbox_matches_skimage(self, label_3d):
+ labels, bboxes = find_all_objects_3D(label_3d)
+ ref = _reference_bboxes_3D(label_3d)
+ for lbl, bbox in zip(labels.tolist(), bboxes.tolist()):
+ z0, z1, r0, r1, c0, c1 = bbox
+ assert (z0, z1, r0, r1, c0, c1) == ref[lbl], (
+ f"Label {lbl}: got {(z0,z1,r0,r1,c0,c1)}, expected {ref[lbl]}"
+ )
+
+ def test_empty_image_returns_empty(self):
+ empty = np.zeros((8, 16, 16), dtype=np.uint32)
+ result = find_all_objects_3D(empty)
+ assert result == ([], [])
+
+ def test_single_voxel_object(self):
+ img = np.zeros((8, 8, 8), dtype=np.uint32)
+ img[3, 4, 5] = 2
+ labels, bboxes = find_all_objects_3D(img)
+ assert labels[0] == 2
+ assert list(bboxes[0]) == [3, 4, 4, 5, 5, 6]
+
+ def test_bbox_dtype_is_uint32(self, label_3d):
+ _, bboxes = find_all_objects_3D(label_3d)
+ assert bboxes.dtype == np.uint32
+
+
+# ---------------------------------------------------------------------------
+# most_common_projection_3D
+# ---------------------------------------------------------------------------
+
+class TestMostCommonProjection3D:
+ @pytest.mark.parametrize(
+ "axis, expected",
+ [
+ (
+ 0,
+ np.array(
+ [
+ [1, 1],
+ [1, 2],
+ [2, 2],
+ [2, 2],
+ [1, 0],
+ [1, 0],
+ ],
+ dtype=np.uint32,
+ ),
+ ),
+ (1, np.array([[1, 2]], dtype=np.uint32)),
+ (2, np.array([[1, 1, 2, 2, 1, 1]], dtype=np.uint32)),
+ ],
+ )
+ def test_counts_across_full_axis_not_runs(self, axis, expected):
+ """A label split into multiple runs must still be counted globally."""
+ lab = np.array(
+ [
+ [
+ [1, 1],
+ [1, 2],
+ [2, 2],
+ [2, 2],
+ [1, 0],
+ [1, 0],
+ ]
+ ],
+ dtype=np.uint32,
+ )
+ out = most_common_projection_3D(lab, axis)
+ assert out.shape == expected.shape
+ np.testing.assert_array_equal(out, expected)
+
+ @pytest.mark.parametrize("axis", [0, 1, 2])
+ def test_ignores_zero_but_returns_zero_if_no_nonzero(self, axis):
+ lab = np.zeros((3, 4, 5), dtype=np.uint32)
+ out = most_common_projection_3D(lab, axis)
+ assert out.dtype == np.uint32
+ np.testing.assert_array_equal(out, np.zeros_like(out, dtype=np.uint32))
+
+ def test_invalid_axis_raises(self):
+ lab = np.zeros((2, 2, 2), dtype=np.uint32)
+ with pytest.raises(ValueError):
+ most_common_projection_3D(lab, 3)
+
+
+# ---------------------------------------------------------------------------
+# object_projections_and_size_3D
+# ---------------------------------------------------------------------------
+
+class TestObjectProjectionsAndSize3D:
+ def test_matches_numpy_binary_projections_and_voxel_count_for_one_id(self):
+ cutout = np.zeros((4, 5, 6), dtype=np.uint32)
+ cutout[0, 1, 2] = 5
+ cutout[1, 1, 2] = 5
+ cutout[2, 3, 4] = 5
+ cutout[3, 0, 1] = 9
+ cutout[0, 0, 0] = 7
+
+ obj_mask = cutout == 5
+ proj_z, proj_y, proj_x, size = object_projections_and_size_3D(cutout, 5)
+
+ np.testing.assert_array_equal(proj_z, np.any(obj_mask, axis=0).astype(np.uint8))
+ np.testing.assert_array_equal(proj_y, np.any(obj_mask, axis=1).astype(np.uint8))
+ np.testing.assert_array_equal(proj_x, np.any(obj_mask, axis=2).astype(np.uint8))
+ assert size == int(np.count_nonzero(obj_mask))
+
+ def test_missing_id_returns_zero_projections_and_zero_size(self):
+ cutout = np.zeros((3, 4, 5), dtype=np.uint32)
+ cutout[1, 2, 3] = 4
+ proj_z, proj_y, proj_x, size = object_projections_and_size_3D(cutout, 8)
+
+ np.testing.assert_array_equal(proj_z, np.zeros((4, 5), dtype=np.uint8))
+ np.testing.assert_array_equal(proj_y, np.zeros((3, 5), dtype=np.uint8))
+ np.testing.assert_array_equal(proj_x, np.zeros((3, 4), dtype=np.uint8))
+ assert size == 0
+
+ def test_projection_dtype_is_uint8(self):
+ cutout = np.zeros((2, 3, 4), dtype=np.uint32)
+ cutout[1, 2, 3] = 7
+ proj_z, proj_y, proj_x, _ = object_projections_and_size_3D(cutout, 7)
+
+ assert proj_z.dtype == np.uint8
+ assert proj_y.dtype == np.uint8
+ assert proj_x.dtype == np.uint8
+
+
+class TestObjectProjectionAndSize3D:
+ @pytest.mark.parametrize("axis", [0, 1, 2])
+ def test_matches_numpy_projection_for_axis(self, axis):
+ cutout = np.zeros((4, 5, 6), dtype=np.uint32)
+ cutout[0, 1, 2] = 5
+ cutout[1, 1, 2] = 5
+ cutout[2, 3, 4] = 5
+ cutout[3, 0, 1] = 9
+
+ proj, size = object_projection_and_size_3D(cutout, 5, axis)
+
+ expected_mask = cutout == 5
+ expected_proj = np.any(expected_mask, axis=axis).astype(np.uint8)
+ np.testing.assert_array_equal(proj, expected_proj)
+ assert size == int(np.count_nonzero(expected_mask))
+
+ def test_invalid_axis_raises(self):
+ cutout = np.zeros((2, 3, 4), dtype=np.uint32)
+ with pytest.raises(ValueError):
+ object_projection_and_size_3D(cutout, 1, 3)
+
+
+# ---------------------------------------------------------------------------
+# calc_IoA_matrix_2D
+# ---------------------------------------------------------------------------
+
+def _run_cython_ioa_2d(lab, prev_lab, use_union):
+ rp = regionprops(lab.astype(np.int32))
+ prev_rp = regionprops(prev_lab.astype(np.int32))
+ curr_IDs_arr = np.array([obj.label for obj in rp], dtype=np.uint32)
+ prev_IDs_arr = np.array([obj.label for obj in prev_rp], dtype=np.uint32)
+ prev_areas_arr = np.array([obj.area for obj in prev_rp], dtype=np.uint32)
+ if use_union:
+ rp_mapper = {obj.label: obj for obj in rp}
+ curr_areas_arr = np.array(
+ [rp_mapper[ID].area for ID in curr_IDs_arr.tolist()], dtype=np.uint32
+ )
+ else:
+ curr_areas_arr = np.empty(0, dtype=np.uint32)
+ return (
+ calc_IoA_matrix_2D(
+ lab.astype(np.uint32), prev_lab.astype(np.uint32),
+ curr_IDs_arr, prev_IDs_arr,
+ prev_areas_arr, curr_areas_arr, use_union,
+ ),
+ rp, prev_rp,
+ )
+
+
+class TestCalcIoAMatrix2D:
+ @pytest.mark.parametrize("use_union", [False, True])
+ def test_matches_python_reference(self, label_2d, label_2d_shifted, use_union):
+ mat_cy, rp, prev_rp = _run_cython_ioa_2d(label_2d_shifted, label_2d, use_union)
+ mat_py, _, _ = _python_ioa_matrix(
+ label_2d_shifted, label_2d, rp, prev_rp, use_union
+ )
+ np.testing.assert_allclose(mat_cy, mat_py, rtol=1e-9, atol=1e-12,
+ err_msg=f"Mismatch for use_union={use_union}")
+
+ def test_shape(self, label_2d, label_2d_shifted):
+ mat_cy, rp, prev_rp = _run_cython_ioa_2d(label_2d_shifted, label_2d, False)
+ assert mat_cy.shape == (len(rp), len(prev_rp))
+
+ def test_values_between_0_and_1(self, label_2d, label_2d_shifted):
+ mat_cy, _, _ = _run_cython_ioa_2d(label_2d_shifted, label_2d, False)
+ assert np.all(mat_cy >= 0.0)
+ assert np.all(mat_cy <= 1.0 + 1e-12)
+
+ def test_no_overlap_gives_zero_matrix(self):
+ """Two images with disjoint objects should produce an all-zero IoA matrix."""
+ lab = _make_disc_label_image((64, 64), [(1, 10, 10, 5)])
+ prev_lab = _make_disc_label_image((64, 64), [(1, 50, 50, 5)])
+ mat_cy, _, _ = _run_cython_ioa_2d(lab, prev_lab, False)
+ np.testing.assert_array_equal(mat_cy, np.zeros_like(mat_cy))
+
+ def test_identical_images_diagonal_is_one(self):
+ """When lab == prev_lab and IDs match, area_prev IoA should be 1."""
+ lab = _make_disc_label_image((64, 64), [
+ (1, 20, 20, 8),
+ (2, 20, 50, 8),
+ ])
+ mat_cy, _, _ = _run_cython_ioa_2d(lab, lab, False)
+ np.testing.assert_allclose(np.diag(mat_cy), 1.0, rtol=1e-9)
+
+ def test_dtype_is_float64(self, label_2d, label_2d_shifted):
+ mat_cy, _, _ = _run_cython_ioa_2d(label_2d_shifted, label_2d, False)
+ assert mat_cy.dtype == np.float64
+
+
+# ---------------------------------------------------------------------------
+# calc_IoA_matrix_3D
+# ---------------------------------------------------------------------------
+
+def _run_cython_ioa_3d(lab, prev_lab, use_union):
+ rp = regionprops(lab.astype(np.int32))
+ prev_rp = regionprops(prev_lab.astype(np.int32))
+ curr_IDs_arr = np.array([obj.label for obj in rp], dtype=np.uint32)
+ prev_IDs_arr = np.array([obj.label for obj in prev_rp], dtype=np.uint32)
+ prev_areas_arr = np.array([obj.area for obj in prev_rp], dtype=np.uint32)
+ if use_union:
+ rp_mapper = {obj.label: obj for obj in rp}
+ curr_areas_arr = np.array(
+ [rp_mapper[ID].area for ID in curr_IDs_arr.tolist()], dtype=np.uint32
+ )
+ else:
+ curr_areas_arr = np.empty(0, dtype=np.uint32)
+ return (
+ calc_IoA_matrix_3D(
+ lab.astype(np.uint32), prev_lab.astype(np.uint32),
+ curr_IDs_arr, prev_IDs_arr,
+ prev_areas_arr, curr_areas_arr, use_union,
+ ),
+ rp, prev_rp,
+ )
+
+
+class TestCalcIoAMatrix3D:
+ @pytest.mark.parametrize("use_union", [False, True])
+ def test_matches_python_reference(self, label_3d, label_3d_shifted, use_union):
+ mat_cy, rp, prev_rp = _run_cython_ioa_3d(label_3d_shifted, label_3d, use_union)
+ mat_py, _, _ = _python_ioa_matrix(
+ label_3d_shifted, label_3d, rp, prev_rp, use_union
+ )
+ np.testing.assert_allclose(mat_cy, mat_py, rtol=1e-9, atol=1e-12,
+ err_msg=f"3D mismatch for use_union={use_union}")
+
+ def test_shape(self, label_3d, label_3d_shifted):
+ mat_cy, rp, prev_rp = _run_cython_ioa_3d(label_3d_shifted, label_3d, False)
+ assert mat_cy.shape == (len(rp), len(prev_rp))
+
+ def test_values_between_0_and_1(self, label_3d, label_3d_shifted):
+ mat_cy, _, _ = _run_cython_ioa_3d(label_3d_shifted, label_3d, False)
+ assert np.all(mat_cy >= 0.0)
+ assert np.all(mat_cy <= 1.0 + 1e-12)
+
+ def test_no_overlap_gives_zero_matrix(self):
+ lab = _make_sphere_label_image((20, 32, 32), [(1, 5, 8, 8, 2, 4, 4)])
+ prev_lab = _make_sphere_label_image((20, 32, 32), [(1, 15, 24, 24, 2, 4, 4)])
+ mat_cy, _, _ = _run_cython_ioa_3d(lab, prev_lab, False)
+ np.testing.assert_array_equal(mat_cy, np.zeros_like(mat_cy))
+
+ def test_identical_images_diagonal_is_one(self):
+ lab = _make_sphere_label_image((20, 32, 32), [
+ (1, 8, 10, 10, 3, 4, 4),
+ (2, 8, 10, 22, 3, 4, 4),
+ ])
+ mat_cy, _, _ = _run_cython_ioa_3d(lab, lab, False)
+ np.testing.assert_allclose(np.diag(mat_cy), 1.0, rtol=1e-9)
+
+ def test_dtype_is_float64(self, label_3d, label_3d_shifted):
+ mat_cy, _, _ = _run_cython_ioa_3d(label_3d_shifted, label_3d, False)
+ assert mat_cy.dtype == np.float64
diff --git a/tests/test_regionprops_cutout_update.py b/tests/test_regionprops_cutout_update.py
new file mode 100644
index 000000000..a5964fb48
--- /dev/null
+++ b/tests/test_regionprops_cutout_update.py
@@ -0,0 +1,328 @@
+import pytest
+import numpy as np
+
+from cellacdc.regionprops import acdcRegionprops
+
+
+def test_update_regionprops_via_cutout_reuses_cutout_object_with_global_coords():
+ old_lab = np.zeros((12, 12), dtype=np.uint16)
+ old_lab[1:3, 1:3] = 1
+
+ new_lab = old_lab.copy()
+ new_lab[5:7, 6:9] = 2
+
+ rp = acdcRegionprops(old_lab)
+ rp.update_regionprops_via_cutout(new_lab, cutout_bbox=(4, 5, 9, 10))
+
+ obj = rp.get_obj_from_ID(2)
+
+ assert obj is not None
+ assert obj.bbox == (5, 6, 7, 9)
+ assert tuple((slc.start, slc.stop) for slc in obj.slice) == ((5, 7), (6, 9))
+ assert obj.centroid == (5.5, 7.0)
+ np.testing.assert_array_equal(
+ obj.coords,
+ np.array(
+ [
+ [5, 6],
+ [5, 7],
+ [5, 8],
+ [6, 6],
+ [6, 7],
+ [6, 8],
+ ]
+ ),
+ )
+ assert new_lab[obj.slice].shape == obj.image.shape
+ assert np.all(new_lab[obj.slice][obj.image] == 2)
+
+
+def test_update_regionprops_via_cutout_refreshes_preserved_id_image():
+ old_lab = np.zeros((10, 10), dtype=np.uint16)
+ old_lab[2:4, 2:4] = 1
+
+ new_lab = old_lab.copy()
+ new_lab[2:5, 2:5] = 1
+
+ rp = acdcRegionprops(old_lab)
+ rp.update_regionprops_via_cutout(new_lab, cutout_bbox=(1, 1, 6, 6))
+
+ obj = rp.get_obj_from_ID(1)
+
+ assert obj is not None
+ assert obj.bbox == (2, 2, 5, 5)
+ np.testing.assert_array_equal(obj.image, new_lab[obj.slice] == 1)
+ np.testing.assert_array_equal(obj.coords, np.argwhere(new_lab == 1))
+
+
+def test_update_regionprops_via_cutout_batches_border_touching_ids():
+ old_lab = np.zeros((14, 14), dtype=np.uint16)
+ old_lab[1:3, 1:3] = 1
+
+ new_lab = old_lab.copy()
+ new_lab[4:9, 5:8] = 2
+ new_lab[7:11, 9:13] = 3
+
+ rp = acdcRegionprops(old_lab)
+ rp.update_regionprops_via_cutout(new_lab, cutout_bbox=(6, 7, 10, 12))
+
+ obj2 = rp.get_obj_from_ID(2)
+ obj3 = rp.get_obj_from_ID(3)
+
+ assert obj2 is not None
+ assert obj2.bbox == (4, 5, 9, 8)
+ np.testing.assert_array_equal(obj2.coords, np.argwhere(new_lab == 2))
+
+ assert obj3 is not None
+ assert obj3.bbox == (7, 9, 11, 13)
+ np.testing.assert_array_equal(obj3.coords, np.argwhere(new_lab == 3))
+
+
+def test_update_regionprops_via_assignments_rebinds_label_image():
+ old_lab = np.zeros((8, 8), dtype=np.uint16)
+ old_lab[2:5, 3:6] = 1
+
+ new_lab = np.zeros_like(old_lab)
+ new_lab[2:5, 3:6] = 7
+
+ rp = acdcRegionprops(old_lab)
+ rp.update_regionprops_via_assignments({1: 7}, new_lab)
+
+ obj = rp.get_obj_from_ID(7)
+
+ assert obj is not None
+ assert rp.lab is new_lab
+ assert obj._label_image is new_lab
+ np.testing.assert_array_equal(obj.image, new_lab[obj.slice] == 7)
+ np.testing.assert_array_equal(obj.coords, np.argwhere(new_lab == 7))
+
+
+def test_slice_regionprops_are_lazy_and_initialized_on_access():
+ lab = np.zeros((3, 6, 6), dtype=np.uint16)
+ lab[1, 2:4, 1:3] = 4
+
+ rp = acdcRegionprops(lab)
+
+ assert rp._slice_rps['z'] == {}
+ assert rp._slice_rps['y'] == {}
+ assert rp._slice_rps['x'] == {}
+
+ z1 = rp.get_slice_rp(1)
+ assert z1 is not None
+ assert 1 in rp._slice_rps['z']
+ assert rp._slice_rps['z'][1] is z1
+ assert z1.lab.ndim == 2
+ assert z1.get_obj_from_ID(4) is not None
+
+
+def test_projection_regionprops_are_lazy_and_initialized_on_access():
+ lab = np.zeros((3, 6, 6), dtype=np.uint16)
+ lab[1, 2:4, 1:3] = 4
+
+ rp = acdcRegionprops(lab)
+
+ assert rp._proj_rps['z'] == {}
+ assert rp._proj_rps['y'] == {}
+ assert rp._proj_rps['x'] == {}
+
+ zmax = rp.get_proj_rp(kind='max', slicing='z')
+ assert zmax is not None
+ assert 'max' in rp._proj_rps['z']
+ assert rp._proj_rps['z']['max'] is zmax
+ assert zmax.lab.ndim == 2
+ assert zmax.get_obj_from_ID(4) is not None
+
+
+def test_projection_regionprops_support_most_common_kind():
+ lab = np.array(
+ [
+ [[0, 1], [2, 2]],
+ [[0, 1], [2, 3]],
+ [[4, 1], [0, 3]],
+ ],
+ dtype=np.uint16,
+ )
+
+ rp = acdcRegionprops(lab)
+ most_common = rp.get_proj_rp(kind='most common', slicing='z')
+
+ expected = np.array(
+ [[4, 1], [2, 3]],
+ dtype=np.uint16,
+ )
+ np.testing.assert_array_equal(most_common.lab, expected)
+ assert rp.get_obj_from_proj_rp(3, kind='most common z-projection', warn=False) is not None
+
+
+def test_most_common_projection_uses_local_cutout_update(monkeypatch):
+ old_lab = np.zeros((3, 6, 6), dtype=np.uint16)
+ old_lab[:, 1:4, 1:4] = 1
+
+ rp = acdcRegionprops(old_lab)
+ proj_before = rp.get_proj_rp(kind='most_common', slicing='z')
+ expected_before = rp._get_lab_projection(old_lab, slicing='z', kind='most_common')
+ np.testing.assert_array_equal(proj_before.lab, expected_before)
+
+ new_lab = old_lab.copy()
+ new_lab[0:2, 2:5, 2:5] = 2
+
+ original_replace_cached = rp._replace_cached_lab_projection
+
+ def _replace_cached_should_not_run_for_most_common(slicing, kind):
+ if kind == 'most_common':
+ raise AssertionError(
+ 'most_common projection should be updated locally for cutout updates.'
+ )
+ return original_replace_cached(slicing, kind)
+
+ monkeypatch.setattr(rp, '_replace_cached_lab_projection', _replace_cached_should_not_run_for_most_common)
+
+ rp.update_regionprops_via_cutout(new_lab, cutout_bbox=(2, 2, 5, 5))
+
+ proj_after = rp.get_proj_rp(kind='most_common', slicing='z')
+ expected_after = rp._get_lab_projection(new_lab, slicing='z', kind='most_common')
+ np.testing.assert_array_equal(proj_after.lab, expected_after)
+
+
+@pytest.mark.parametrize('slicing', ['z', 'y', 'x'])
+def test_most_common_projection_uses_local_cutout_update_for_all_slicings(monkeypatch, slicing):
+ old_lab = np.zeros((4, 7, 8), dtype=np.uint16)
+ old_lab[1:3, 1:4, 1:4] = 1
+ old_lab[0:2, 4:6, 4:7] = 2
+
+ rp = acdcRegionprops(old_lab)
+ proj_before = rp.get_proj_rp(kind='most_common', slicing=slicing)
+ expected_before = rp._get_lab_projection(old_lab, slicing=slicing, kind='most_common')
+ np.testing.assert_array_equal(proj_before.lab, expected_before)
+
+ new_lab = old_lab.copy()
+ new_lab[:, 2:6, 3:7] = np.array(
+ [
+ [
+ [0, 2, 2, 0],
+ [3, 3, 2, 0],
+ [3, 3, 2, 0],
+ [0, 0, 0, 0],
+ ],
+ [
+ [0, 2, 2, 0],
+ [3, 3, 3, 0],
+ [3, 3, 3, 0],
+ [0, 0, 0, 0],
+ ],
+ [
+ [0, 0, 0, 0],
+ [3, 3, 3, 0],
+ [3, 3, 3, 0],
+ [0, 0, 0, 0],
+ ],
+ [
+ [0, 0, 0, 0],
+ [0, 0, 0, 0],
+ [0, 0, 0, 0],
+ [0, 0, 0, 0],
+ ],
+ ],
+ dtype=np.uint16,
+ )
+
+ original_replace_cached = rp._replace_cached_lab_projection
+
+ def _replace_cached_should_not_run_for_most_common(slicing_name, kind):
+ if kind == 'most_common':
+ raise AssertionError(
+ 'most_common projection should be updated locally for cutout updates.'
+ )
+ return original_replace_cached(slicing_name, kind)
+
+ monkeypatch.setattr(
+ rp,
+ '_replace_cached_lab_projection',
+ _replace_cached_should_not_run_for_most_common,
+ )
+
+ rp.update_regionprops_via_cutout(new_lab, cutout_bbox=(2, 3, 6, 7))
+
+ proj_after = rp.get_proj_rp(kind='most_common', slicing=slicing)
+ expected_after = rp._get_lab_projection(new_lab, slicing=slicing, kind='most_common')
+ np.testing.assert_array_equal(proj_after.lab, expected_after)
+
+
+def test_get_obj_from_id_for_stored_slice_and_projection_rps():
+ lab = np.zeros((4, 8, 8), dtype=np.uint16)
+ lab[1:3, 2:5, 3:6] = 2
+
+ rp = acdcRegionprops(lab)
+
+ obj_slice = rp.get_obj_from_slice_rp(2, slice_number=1, slicing='z', warn=False)
+ assert obj_slice is not None
+
+ obj_proj = rp.get_obj_from_proj_rp(2, kind='max', slicing='z', warn=False)
+ assert obj_proj is not None
+
+
+def test_slice_regionprops_follow_assignments_and_deletions():
+ lab = np.zeros((4, 8, 8), dtype=np.uint16)
+ lab[1:3, 2:5, 3:6] = 2
+
+ rp = acdcRegionprops(lab)
+ _ = rp.get_slice_rp(1, 'z')
+ _ = rp.get_slice_rp(2, 'y')
+ _ = rp.get_slice_rp(3, 'x')
+ _ = rp.get_proj_rp('max', 'z')
+ _ = rp.get_proj_rp('mean', 'y')
+ _ = rp.get_proj_rp('median', 'x')
+
+ remapped_lab = np.zeros_like(lab)
+ remapped_lab[1:3, 2:5, 3:6] = 9
+ rp.update_regionprops_via_assignments({2: 9}, remapped_lab)
+
+ assert rp.get_slice_rp(1, 'z').get_obj_from_ID(9, warn=False) is not None
+ assert rp.get_slice_rp(2, 'y').get_obj_from_ID(9, warn=False) is not None
+ assert rp.get_slice_rp(3, 'x').get_obj_from_ID(9, warn=False) is not None
+ assert rp.get_proj_rp('max', 'z').get_obj_from_ID(9, warn=False) is not None
+ assert rp.get_slice_rp(1, 'z').get_obj_from_ID(2, warn=False) is None
+
+ rp.update_regionprops_via_deletions({9})
+ assert rp.get_slice_rp(1, 'z').get_obj_from_ID(9, warn=False) is None
+ assert rp.get_slice_rp(2, 'y').get_obj_from_ID(9, warn=False) is None
+ assert rp.get_slice_rp(3, 'x').get_obj_from_ID(9, warn=False) is None
+ assert rp.get_proj_rp('max', 'z').get_obj_from_ID(9, warn=False) is None
+
+
+def test_slice_regionprops_update_from_2d_cutout_on_3d():
+ old_lab = np.zeros((5, 12, 12), dtype=np.uint16)
+ old_lab[1:4, 2:4, 2:4] = 5
+
+ new_lab = old_lab.copy()
+ new_lab[3, 6:9, 7:10] = 6
+
+ rp = acdcRegionprops(old_lab)
+ _ = rp.get_slice_rp(3, 'z')
+ _ = rp.get_slice_rp(6, 'y')
+ _ = rp.get_slice_rp(7, 'x')
+
+ # 2D cutout on y/x; implementation expands to all z for touched IDs.
+ rp.update_regionprops_via_cutout(new_lab, cutout_bbox=(5, 6, 10, 11))
+
+ assert rp.get_obj_from_ID(6, warn=False) is not None
+ assert rp.get_slice_rp(3, 'z').get_obj_from_ID(6, warn=False) is not None
+ assert rp.get_slice_rp(6, 'y').get_obj_from_ID(6, warn=False) is not None
+ assert rp.get_slice_rp(7, 'x').get_obj_from_ID(6, warn=False) is not None
+
+
+def test_projection_lab_sorted_draws_large_first_small_on_top():
+ lab = np.zeros((4, 8, 8), dtype=np.uint16)
+ # Larger object
+ lab[0:4, 1:7, 1:7] = 1
+ # Smaller object overlapping the center
+ lab[1:3, 3:5, 3:5] = 2
+
+ rp = acdcRegionprops(lab)
+ proj = rp.get_projection_lab_sorted(slicing='z')
+
+ # Small object should be visible on top in overlap area.
+ np.testing.assert_array_equal(proj[3:5, 3:5], np.full((2, 2), 2, dtype=np.uint16))
+ assert proj[2, 2] == 1
+
+
diff --git a/tests/test_segm_utils.py b/tests/test_segm_utils.py
new file mode 100644
index 000000000..a8f4c850e
--- /dev/null
+++ b/tests/test_segm_utils.py
@@ -0,0 +1,175 @@
+import types
+
+import numpy as np
+from skimage.measure import regionprops
+from types import SimpleNamespace
+
+from cellacdc import myutils
+from cellacdc.regionprops import acdcRegionprops
+from cellacdc.workers import SegForLostIDsWorker
+
+from cellacdc.models.thresholding.acdcSegment import Model as ThresholdingModel
+from cellacdc.segm_utils import get_best_overlapping_label
+
+
+def test_get_best_overlapping_label_uses_majority_overlap_with_allowed_labels():
+ label_img = np.zeros((8, 8), dtype=np.uint16)
+ label_img[2:6, 2:4] = 4
+ label_img[3:7, 4:6] = 7
+
+ obj = types.SimpleNamespace(
+ slice=(slice(2, 7), slice(2, 6)),
+ image=np.array(
+ [
+ [False, False, False, False],
+ [True, True, False, False],
+ [True, True, True, True],
+ [True, True, True, True],
+ [False, False, True, True],
+ ]
+ ),
+ )
+
+ assert get_best_overlapping_label(label_img, obj, {4, 7}) == 4
+
+
+def test_get_best_overlapping_label_returns_none_without_allowed_overlap():
+ label_img = np.zeros((6, 6), dtype=np.uint16)
+ label_img[1:3, 1:3] = 2
+
+ obj = types.SimpleNamespace(
+ slice=(slice(1, 4), slice(1, 4)),
+ image=np.array(
+ [
+ [False, True, False],
+ [True, True, True],
+ [False, True, False],
+ ]
+ ),
+ )
+
+ assert get_best_overlapping_label(label_img, obj, {5}) is None
+
+
+def test_thresholding_model_object_can_be_mapped_back_to_missing_id():
+ prev_lab = np.zeros((10, 10), dtype=np.uint16)
+ prev_lab[3:7, 3:7] = 5
+
+ image = np.zeros((10, 10), dtype=np.float32)
+ image[3:7, 3:7] = 10.0
+
+ model = ThresholdingModel()
+ model_lab = model.segment(
+ image,
+ gauss_sigma=0,
+ threshold_method='threshold_otsu',
+ )
+
+ rp_model = regionprops(model_lab)
+ assert len(rp_model) == 1
+
+ recovered_id = get_best_overlapping_label(prev_lab, rp_model[0], {5})
+
+ assert recovered_id == 5
+
+
+class _DummyLogger:
+ def info(self, message):
+ pass
+
+ def warning(self, message):
+ pass
+
+ def error(self, message):
+ pass
+
+
+class _DummySignal:
+ def emit(self, *args, **kwargs):
+ pass
+
+
+class _DummySignals:
+ def __init__(self):
+ self.progress = _DummySignal()
+ self.finished = _DummySignal()
+ self.initProgressBar = _DummySignal()
+ self.progressBar = _DummySignal()
+ self.critical = _DummySignal()
+
+
+def test_seg_for_lost_ids_worker_thresholding_relabels_recovered_object(monkeypatch):
+ prev_lab = np.zeros((10, 10), dtype=np.uint16)
+ prev_lab[3:7, 3:7] = 5
+
+ curr_lab = np.zeros((10, 10), dtype=np.uint16)
+ curr_img = np.zeros((10, 10), dtype=np.float32)
+ curr_img[3:7, 3:7] = 10.0
+
+ prev_rp = acdcRegionprops(prev_lab)
+ curr_rp = acdcRegionprops(curr_lab)
+
+ posData = SimpleNamespace(
+ frame_i=1,
+ lab=curr_lab.copy(),
+ rp=curr_rp,
+ allData_li=[
+ {'labels': prev_lab, 'regionprops': prev_rp},
+ {'labels': curr_lab, 'regionprops': curr_rp},
+ ],
+ )
+
+ guiWin = SimpleNamespace(
+ data=[posData],
+ pos_i=0,
+ SegForLostIDsSettings={
+ 'models_settings': [
+ {
+ 'base_model_name': 'thresholding',
+ 'init_kwargs_new': {},
+ 'args_new': {
+ 'distance_filler_growth': 1.0,
+ 'overlap_threshold': 0.5,
+ 'padding': 1.0,
+ 'size_perc_diff': 1.0,
+ 'allow_only_tracked_cells': True,
+ },
+ 'init_kwargs': {},
+ 'model_kwargs': {
+ 'gauss_sigma': 0,
+ 'threshold_method': 'threshold_otsu',
+ },
+ 'preproc_recipe': None,
+ 'applyPostProcessing': False,
+ 'standardPostProcessKwargs': {},
+ 'customPostProcessFeatures': None,
+ 'customPostProcessGroupedFeatures': None,
+ }
+ ]
+ },
+ getDisplayedImg1=lambda: curr_img,
+ get_2Dlab=lambda lab: lab,
+ getTrackedLostIDs=lambda: [],
+ setBrushID=lambda useCurrentLab=True, return_val=True: 10,
+ logger=_DummyLogger(),
+ )
+
+ worker = SegForLostIDsWorker(guiWin, mutex=SimpleNamespace(lock=lambda: None, unlock=lambda: None), waitCond=SimpleNamespace(wait=lambda mutex: None))
+ worker.signals = _DummySignals()
+ worker.logger = _DummyLogger()
+ worker.gpu_go = True
+ worker.dont_force_cpu = True
+
+ monkeypatch.setattr(worker, 'emitSigAskInit', lambda: None)
+ monkeypatch.setattr(worker, 'emitSigAskInstallGPU', lambda base_model_name, use_gpu: None)
+ monkeypatch.setattr(worker, 'emitSigUpdateRP', lambda wl_update=True, wl_track_og_curr=False: None)
+ monkeypatch.setattr(worker, 'emitSigStoreData', lambda autosave=True: None)
+ monkeypatch.setattr(worker, 'emitTrackManuallyAddedObject', lambda *args, **kwargs: None)
+ monkeypatch.setattr(myutils, 'import_segment_module', lambda base_model_name: SimpleNamespace(Model=ThresholdingModel))
+ monkeypatch.setattr(myutils, 'init_segm_model', lambda acdcSegment, posData, init_kwargs_new: ThresholdingModel())
+
+ monkeypatch.setattr(worker, 'emitGetSegForLostIDsInputImg', lambda image_channel_name: curr_img)
+ worker.run()
+
+ assert posData.lab[3:7, 3:7].min() == 5
+ assert posData.lab[3:7, 3:7].max() == 5
\ No newline at end of file