Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 19 additions & 11 deletions ClearMap/ImageProcessing/Experts/Cells.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
Expert cell image processing pipeline.

This module provides the basic routines for processing immediate early
gene data.
gene data.

The routines are used in the :mod:`ClearMap.Scripts.CellMap` pipeline.
"""
Expand Down Expand Up @@ -82,7 +82,7 @@
shape=3,
measure=['source', 'background_correction']),
)
"""Parameter for the cell detection pipeline.
"""Parameter for the cell detection pipeline.
See :func:`detect_cells` for details."""


Expand All @@ -96,15 +96,15 @@
verbose=None,
processes=None
)
"""Parallel processing parameter for the cell detection pipeline.
See :func:`ClearMap.ParallelProcessing.BlockProcessing.process` for details."""
"""Parallel processing parameter for the cell detection pipeline.
See :func:`ClearMap.ParallelProcessing.BlockProcessing.process` for details."""


###############################################################################
# ## Cell detection
###############################################################################
def detect_cells(source, sink=None, cell_detection_parameter=default_cell_detection_parameter,

def detect_cells(source, sink=None, block_detection_function=None, cell_detection_parameter=default_cell_detection_parameter,
processing_parameter=default_cell_detection_processing_parameter, workspace=None):
"""Cell detection pipeline.

Expand All @@ -114,6 +114,10 @@ def detect_cells(source, sink=None, cell_detection_parameter=default_cell_detect
The source of the stitched raw data.
sink : sink specification or None
The sink to write the result to. If None, an array is returned.
block_detection_function : function or None
The function apply to each block. If function, it should return a tuple of arrays
with the following specs: ndims should be 2, dimension 1 of first array should be 3,
and dimension 0 of all arrays should be the equal.
cell_detection_parameter : dict
Parameter for the binarization. See below for details.
processing_parameter : dict
Expand Down Expand Up @@ -305,7 +309,11 @@ def detect_cells(source, sink=None, cell_detection_parameter=default_cell_detect
n_processes = multiprocessing.cpu_count() if processing_parameter.get('processes') is None else processing_parameter.get('processes')
n_threads = int(multiprocessing.cpu_count() / n_processes) # Number of threads so that * n_processes, fills CPUs

results, blocks = bp.process(detect_cells_block, source, sink=None, function_type='block', return_result=True,
if block_detection_function is None:
# use default cell detection function
block_detection_function = detect_cells_block

results, blocks = bp.process(block_detection_function, source, sink=None, function_type='block', return_result=True,
return_blocks=True, parameter=cell_detection_parameter, workspace=workspace,
**{**processing_parameter, **{'n_threads': n_threads}})

Expand Down Expand Up @@ -407,7 +415,7 @@ def detect_cells_block(source, parameter=default_cell_detection_parameter, n_thr
maxima_labels, _ = ndi.label(maxima, structure=np.ones((3,)*3,dtype='bool'))
centers = np.vstack(md.label_representatives(maxima_labels)).transpose()
# we could come back to the ancient version
# centers = ap.where(maxima, processes=n_threads).array
# centers = ap.where(maxima, processes=n_threads).array
del maxima

# correct for valid region
Expand Down Expand Up @@ -436,9 +444,9 @@ def detect_cells_block(source, parameter=default_cell_detection_parameter, n_thr
shape = None
else:
raise err





valid = sizes > 0
results += (sizes,)
else:
Expand Down Expand Up @@ -479,7 +487,7 @@ def detect_cells_block(source, parameter=default_cell_detection_parameter, n_thr

if parameter.get('verbose'):
total_time.print_elapsed_time('Cell detection')

gc.collect()

return results
Expand Down
138 changes: 138 additions & 0 deletions ClearMap/Scripts/custom_cells.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# TODO module-level documentation: This module is an example of a custom pipeline
"""
Example usage:
```
from ClearMap.Scripts.custom_cells import detect_cells, custom_detect_cells_block, default_cell_detection_parameter
from skimage.data import cells3d

test_data = cells3d()[:,1]
sink = detect_cells(test_data, block_detection_function=custom_detect_cells_block, cell_detection_parameter=cell_detection_parameter)
```
"""

import gc

import numpy as np

from skimage.filters import threshold_otsu
from skimage.morphology import (
remove_small_objects,
remove_small_holes,
binary_closing,
ball,
)
from skimage.measure import label, regionprops

from ClearMap.ImageProcessing.Experts.Cells import (
default_cell_detection_parameter,
detect_cells,
) # We need `detect_cells` imported from the same module as `custom_detect_cells_block`
import ClearMap.Utils.Timer as tmr
from ClearMap.ImageProcessing.Experts.utils import run_step


default_cell_detection_parameter["custom_step"] = {
"save": False
} # The `cell_detection_parameter` dictionnary needs a key with the name of the new custom step


def custom_detection(array, **kwargs):
"""Detect cells in a 3D array.
Outputs a tuple of arrays representing the coordinates (x, y, z) of the cells **in the array**,
and properties of each detected cell."""

## detect intensities above otsu threshold
mask = array > threshold_otsu(array)

## postprocess mask
mask = remove_small_objects(mask, min_size=1000)
mask = binary_closing(mask, ball(1))
mask = remove_small_holes(mask, area_threshold=10_000)

## extract individual connected components
labeled_mask = label(mask.astype(int), connectivity=1)

## compute properties of the connected components
region_properties = regionprops(labeled_mask, array)
coordinates = np.array([region.centroid for region in region_properties])
intensities = np.array([region.mean_intensity for region in region_properties])
maxima = np.array([region.max_intensity for region in region_properties])
shapes = np.array([region.area for region in region_properties])

return coordinates, intensities, maxima, shapes


def custom_detect_cells_block(source, parameter=default_cell_detection_parameter, n_threads=None):
"""Detect cells in a block of a 3D image.
Outputs a tuple of arrays, representing the coordinates (x, y, z) of the cells **in the image**,
and properties of each detected cell.
That function can be passed to CellMap's `detect_cells` function."""

# initialize parameter and slicing
if parameter.get("verbose"):
prefix = "Block %s: " % (source.info(),)
total_time = tmr.Timer(prefix)
else:
prefix = ""

base_slicing = source.valid.base_slicing
valid_slicing = source.valid.slicing
valid_lower = source.valid.lower
valid_upper = source.valid.upper
lower = source.lower

# Measurements that will be performed per cell
steps_to_measure = {} # FIXME: rename

## intensity
parameter_intensity = parameter.get("intensity_detection")
if parameter_intensity:
parameter_intensity = parameter_intensity.copy()
measure = parameter_intensity.pop("measure", [])
measure = measure if measure else []
## validation
valid_measurement_keys = list(default_cell_detection_parameter.keys()) + [
"source"
]
for m in measure:
if m not in valid_measurement_keys:
raise KeyError(f"Unknown measurement: {m}")
steps_to_measure[m] = None
## in case source is measured, the image used is source
if "source" in steps_to_measure:
steps_to_measure["source"] = np.array(source.array)
## other cases seem to not be supported

step_params = {
"parameter": parameter,
"steps_to_measure": steps_to_measure,
"prefix": prefix,
"base_slicing": base_slicing,
"valid_slicing": valid_slicing,
}

# WARNING: if param_illumination: previous_step = source, not np.array(source.array)

results = run_step(
"custom_step", np.array(source.array), custom_detection, **step_params
)

# correct coordinate offsets of blocks
results = (results[0] + lower,) + results[1:]

# remove cells outside the valid region of the block
valid_mask = np.all(
(results[0] >= valid_lower) & (results[0] < valid_upper), axis=1
)
results = tuple(result[valid_mask] for result in results)

# ensure all results array have 2 dimensions, so that they are ready to be vstacked
results = tuple(
result[:, None] if result.ndim == 1 else result for result in results
)

if parameter.get("verbose"):
total_time.print_elapsed_time("Cell detection")

gc.collect()
return results
Loading