Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 97 additions & 8 deletions src/cellmap_analyze/process/skeletonize.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import numpy as np
import pandas as pd
import networkx as nx
import edt as edt_module
from cellmap_analyze.util import dask_util
from cellmap_analyze.util.image_data_interface import ImageDataInterface
from cellmap_analyze.util.mixins import ComputeConfigMixin
Expand Down Expand Up @@ -70,6 +72,15 @@ def __init__(
logger.info(f"Loaded {len(self.ids)} IDs from {csv_path}")
logger.info(f"Output will be written to {output_path}")

@staticmethod
def _empty_metrics():
return {
"longest_shortest_path_nm": 0.0,
"num_branches": 0,
"radius_mean_nm": np.nan,
"radius_std_nm": np.nan,
}

@staticmethod
def calculate_id_skeleton(
id_value,
Expand Down Expand Up @@ -120,7 +131,10 @@ def calculate_id_skeleton(
# Check if there's any data
if not np.any(data):
logger.warning(f"No voxels found for ID {id_value}, skipping")
return
return Skeletonize._empty_metrics()

# Compute EDT on pre-erosion mask for approximate radii
distance_transform = edt_module.edt(data, anisotropy=tuple(voxel_size))

# Apply erosion if requested
if erosion:
Expand All @@ -147,22 +161,32 @@ def calculate_id_skeleton(
simplified_path = f"{output_path}/simplified/{id_value}"
empty_skeleton.write_neuroglancer_skeleton(simplified_path)
logger.info(f"Wrote empty skeleton for ID {id_value}")
return
return Skeletonize._empty_metrics()

# Skeletonize
# Skeletonize using Lee's algorithm (skimage default). It has
# known limitations (e.g. thin structures may lose branches) but
# is sufficient for now.
skel = skeletonize(data)

# Check if skeletonization produced anything
if not np.any(skel):
logger.warning(f"Skeletonization produced no voxels for ID {id_value}, writing empty skeleton")
logger.warning(
f"Skeletonization produced no voxels for ID {id_value}, writing empty skeleton"
)
# Write empty skeleton files
empty_skeleton = CustomSkeleton(vertices=[], edges=[])
full_path = f"{output_path}/full/{id_value}"
empty_skeleton.write_neuroglancer_skeleton(full_path)
simplified_path = f"{output_path}/simplified/{id_value}"
empty_skeleton.write_neuroglancer_skeleton(simplified_path)
logger.info(f"Wrote empty skeleton for ID {id_value}")
return
return Skeletonize._empty_metrics()

# Sample radii at skeleton voxel positions
skel_coords = np.argwhere(skel)
radii = distance_transform[
skel_coords[:, 0], skel_coords[:, 1], skel_coords[:, 2]
]

# Convert to custom skeleton format
# spacing parameter scales the vertices by voxel_size
Expand Down Expand Up @@ -196,6 +220,35 @@ def calculate_id_skeleton(
else:
pruned = skeleton

# Compute skeleton metrics on pruned skeleton
num_branches = len(pruned.polylines)
longest_shortest_path = 0.0
if len(pruned.vertices) > 1:
pruned_graph = pruned.skeleton_to_graph()
for component in nx.connected_components(pruned_graph):
if len(component) < 2:
continue
subgraph = pruned_graph.subgraph(component)
start = next(iter(component))
lengths = nx.single_source_dijkstra_path_length(
subgraph, start, weight="weight"
)
far_node = max(lengths, key=lengths.get)
lengths2 = nx.single_source_dijkstra_path_length(
subgraph, far_node, weight="weight"
)
component_diameter = max(lengths2.values())
longest_shortest_path = max(
longest_shortest_path, component_diameter
)

skeleton_metrics = {
"longest_shortest_path_nm": longest_shortest_path,
"num_branches": num_branches,
"radius_mean_nm": float(np.mean(radii)),
"radius_std_nm": float(np.std(radii)),
}

# Simplify
if tolerance_nm > 0:
simplified = pruned.simplify(tolerance_nm)
Expand All @@ -214,7 +267,7 @@ def calculate_id_skeleton(
simplified_path = f"{output_path}/simplified/{id_value}"
empty_skeleton.write_neuroglancer_skeleton(simplified_path)
logger.info(f"Wrote empty skeleton for ID {id_value}")
return
return Skeletonize._empty_metrics()

# Ensure edges are properly shaped numpy arrays before writing
# Handle case where there are no edges (single vertex)
Expand Down Expand Up @@ -242,6 +295,8 @@ def calculate_id_skeleton(
f"Wrote simplified skeleton for ID {id_value}: {len(simplified.vertices)} vertices"
)

return skeleton_metrics

except Exception as e:
logger.error(f"Error processing ID {id_value}: {e}", exc_info=True)
raise
Expand Down Expand Up @@ -313,18 +368,48 @@ def skeletonize(self):

# Parallelize over IDs using dask
num_ids = len(self.ids)
tmp_merge_dir = f"{self.output_path}/_tmp_skeleton_metrics_to_merge"

dask_util.compute_blockwise_partitions(
skeleton_metrics = dask_util.compute_blockwise_partitions(
num_ids,
self.num_workers,
self.compute_args,
logger,
f"skeletonizing {num_ids} IDs from {self.segmentation_idi.path}",
self._skeletonize_id_wrapper,
merge_info=(Skeletonize._merge_skeleton_metrics, tmp_merge_dir),
)

self._write_skeleton_csv(skeleton_metrics)

logger.info("Skeletonization complete")

@staticmethod
def _merge_skeleton_metrics(list_of_results):
merged = []
for result in list_of_results:
merged.append(result)
return merged

def _write_skeleton_csv(self, skeleton_metrics):
    """Join per-ID skeleton metrics onto the input CSV and write the result.

    The output file sits next to ``self.csv_path`` and carries a
    ``_with_skeletons`` suffix before the extension.

    Args:
        skeleton_metrics: list of per-ID metrics dicts, each containing
            an "id" key plus the raw metric fields.
    """
    # Human-readable column names for the merged CSV.
    column_labels = {
        "longest_shortest_path_nm": "Longest Shortest Path (nm)",
        "num_branches": "Number of Branches",
        "radius_mean_nm": "Radius Mean (nm)",
        "radius_std_nm": "Radius Std (nm)",
    }
    # Index the metrics by "id" so the join lines up with the original
    # CSV's index column.
    metrics_table = (
        pd.DataFrame(skeleton_metrics)
        .set_index("id")
        .rename(columns=column_labels)
    )
    base_table = pd.read_csv(self.csv_path, index_col=0)
    # splitext on the full path == dirname + extension-free basename.
    path_root, _ = os.path.splitext(self.csv_path)
    output_csv = f"{path_root}_with_skeletons.csv"
    base_table.join(metrics_table).to_csv(output_csv)
    logger.info(f"Wrote skeleton metrics CSV to {output_csv}")

def _skeletonize_id_wrapper(self, index):
"""
Wrapper to call calculate_id_skeleton with the appropriate ID.
Expand All @@ -333,7 +418,7 @@ def _skeletonize_id_wrapper(self, index):
index: Index into self.ids list
"""
id_value = self.ids[index]
Skeletonize.calculate_id_skeleton(
result = Skeletonize.calculate_id_skeleton(
id_value,
self.segmentation_idi,
self.bbox_df,
Expand All @@ -342,3 +427,7 @@ def _skeletonize_id_wrapper(self, index):
self.min_branch_length_nm,
self.tolerance_nm,
)
if result is None:
result = Skeletonize._empty_metrics()
result["id"] = id_value
return result
15 changes: 8 additions & 7 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,13 +646,14 @@ def segmentation_for_skeleton():
) ** 2 <= radius**2
seg[sphere_mask] = 4

# ID 5: Cross/plus shape (3D cross with branches in X, Y, Z directions)
# Center bar along Z
seg[35:45, 20:22, 20:22] = 5
# Branch along X
seg[39:41, 20:22, 15:25] = 5
# Branch along Y
seg[39:41, 15:25, 20:22] = 5
# ID 5: Cross/plus shape (3D cross with branches in Z, X, Y directions)
# Thick arms (4x4 cross-section) so skeletonize produces real branches.
# Z arm (top/bottom of junction): z=30..49, 20 voxels long
seg[30:50, 23:27, 23:27] = 5
# X arm (longest): z=38..41, x=10..39, 30 voxels long
seg[38:42, 23:27, 10:40] = 5
# Y arm: z=38..41, y=14..35, 22 voxels long
seg[38:42, 14:36, 23:27] = 5

# ID 6: L-shaped structure
seg[5:10, 30:35, 30:32] = 6 # Vertical part
Expand Down
20 changes: 6 additions & 14 deletions tests/operations/test_morphological_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@
import fastmorph
import numpy as np
from cellmap_analyze.util.image_data_interface import (
ImageDataInterface,
)
ImageDataInterface,
)


@pytest.mark.parametrize("operation", ["erosion", "dilation"])
def test_morphological_operations(tmp_zarr, segmentation_cylinders, operation):
num_iterations=1
num_iterations = 1
mo = MorphologicalOperations(
input_path=f"{tmp_zarr}/segmentation_cylinders/s0",
output_path=f"{tmp_zarr}/test_morphological_{operation}",
num_workers=1,
operation=operation,
iterations=num_iterations
iterations=num_iterations,
)
mo.perform_morphological_operation()
test_data = ImageDataInterface(
Expand All @@ -24,16 +24,8 @@ def test_morphological_operations(tmp_zarr, segmentation_cylinders, operation):
ground_truth = segmentation_cylinders.copy()

if operation == "erosion":
ground_truth = fastmorph.erode(ground_truth,iterations=num_iterations)
ground_truth = fastmorph.erode(ground_truth, iterations=num_iterations)
else:
ground_truth = fastmorph.dilate(ground_truth,iterations=num_iterations)
ground_truth = fastmorph.dilate(ground_truth, iterations=num_iterations)

assert np.array_equal(test_data, ground_truth)
# %%
# import fastmorph
# from cellmap_analyze.util.image_data_interface import (
# ImageDataInterface,
# )
# from cellmap_analyze.util.neuroglancer_util import view_in_neuroglancer
# view_in_neuroglancer(original=ImageDataInterface("/tmp/pytest-of-ackermand/pytest-29/tmp0/tmp.zarr/image_with_holes/s0").to_ndarray_ts(),gt=fastmorph.dilate(ImageDataInterface("/tmp/pytest-of-ackermand/pytest-29/tmp0/tmp.zarr/image_with_holes/s0").to_ndarray_ts(),iterations=5),test=ImageDataInterface("/tmp/pytest-of-ackermand/pytest-29/tmp0/tmp.zarr/test_morphological_dilation/s0").to_ndarray_ts())
# %%
25 changes: 25 additions & 0 deletions tests/operations/test_skeletonize.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,31 @@ def test_skeletonize_without_erosion(tmp_zarr, tmp_skeletonize_csv):
len(full_verts) > 0
), f"ID {id_val}: No vertices in skeleton without erosion"

# Verify skeleton metrics CSV was written with expected columns
csv_dir = os.path.dirname(tmp_skeletonize_csv)
csv_basename = os.path.splitext(os.path.basename(tmp_skeletonize_csv))[0]
metrics_csv_path = os.path.join(csv_dir, f"{csv_basename}_with_skeletons.csv")
assert os.path.exists(metrics_csv_path), "Skeleton metrics CSV not created"
metrics_df = pd.read_csv(metrics_csv_path, index_col=0)

# ID 5 (cross shape) should have meaningful skeleton metrics
row5 = metrics_df.loc[5]

# Cross has 3 arms meeting at a junction -> exactly 3 branches
assert (
row5["Number of Branches"] == 3
), f"Cross (ID 5) should have 3 branches, got {row5['Number of Branches']}"

# Longest shortest path should be approximately 160 nm
assert (
abs(row5["Longest Shortest Path (nm)"] - 160) < 20
), f"Cross (ID 5) longest shortest path should be ~160 nm, got {row5['Longest Shortest Path (nm)']}"

# Radii should be approximately 16 nm (arms are 4 voxels wide, voxel_size=8nm)
assert (
abs(row5["Radius Mean (nm)"] - 16) < 4
), f"Cross (ID 5) radius mean should be ~16 nm, got {row5['Radius Mean (nm)']}"


def test_skeletonize_with_pruning_and_simplification(tmp_zarr, tmp_skeletonize_csv):
"""Test that pruning and simplification reduce the skeleton complexity."""
Expand Down
Loading