From 2f6ec7949e9614e096635d3c85898e3af0148e75 Mon Sep 17 00:00:00 2001 From: David Ackerman Date: Tue, 5 Aug 2025 11:34:02 -0400 Subject: [PATCH 1/4] Fix block shape calculation and fix channel calculation for dacapo affinities --- cellmap_flow/utils/data.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/cellmap_flow/utils/data.py b/cellmap_flow/utils/data.py index c6e1c42..1313ec3 100644 --- a/cellmap_flow/utils/data.py +++ b/cellmap_flow/utils/data.py @@ -122,7 +122,7 @@ def _get_config(self): config.output_channels = len( config.channels ) # 0:all_mem,1:organelle,2:mito,3:er,4:nucleus,5:pm,6:vs,7:ld - config.block_shape = np.array(tuple(out_shape) + (len(channels),)) + config.block_shape = np.array(tuple(out_shape) + (config.output_channels,)) return config @@ -384,7 +384,8 @@ def get_dacapo_channels(task): if hasattr(task, "channels"): return task.channels elif type(task).__name__ == "AffinitiesTask": - return ["x", "y", "z"] + # to be backwards compatible in case .channels or .neighborhood doesn't exist + return [f"aff_{'.'.join(map(str, n))}" for n in task.predictor.neighborhood] else: return ["membrane"] @@ -636,3 +637,11 @@ def _get_config(self) -> Config: config.model.to(device) config.model.eval() return config + + +# %% +# config, task = DaCapoModelConfig( +# run_name="finetuned_07-31-24_3d_lsdaffs_jrc_mus-liver-zon-1_07-31-24_nuclear_pores_pseudorandom_training_centers_unet_default_v2_no_dataset_predictor_node_lr_5E-5__0", +# iteration=55000, +# name="Example DaCapo Model", +# )._get_config() From 8663ac0b5cc270753625773dd70f1b458b0f2159 Mon Sep 17 00:00:00 2001 From: David Ackerman Date: Tue, 5 Aug 2025 11:34:19 -0400 Subject: [PATCH 2/4] Refactor AffinityPostprocessor to include filtering of fragments based on mean affinity values and update initialization parameters for bias and filtering. 
--- cellmap_flow/post/postprocessors.py | 111 +++++++++++++++++++++++----- 1 file changed, 92 insertions(+), 19 deletions(-) diff --git a/cellmap_flow/post/postprocessors.py b/cellmap_flow/post/postprocessors.py index ef814c9..ccb377e 100644 --- a/cellmap_flow/post/postprocessors.py +++ b/cellmap_flow/post/postprocessors.py @@ -7,6 +7,7 @@ import threading from scipy.ndimage import label import mwatershed as mws +from scipy.ndimage import measurements, gaussian_filter from scipy.ndimage import measurements import fastremap from funlib.math import cantor_number @@ -131,7 +132,9 @@ def is_segmentation(self): class AffinityPostprocessor(PostProcessor): def __init__( self, - bias: float = 0.0, + adjacent_edge_bias: float = -0.4, + lr_bias_ratio: float = -0.175, + filter_val: float = 0.5, neighborhood: str = """[ [1, 0, 0], [0, 1, 0], @@ -145,36 +148,106 @@ def __init__( ]""", ): use_exact = "True" - self.bias = float(bias) + self.adjacent_edge_bias = float(adjacent_edge_bias) + self.lr_bias_ratio = float(lr_bias_ratio) + self.filter_val = float(filter_val) self.neighborhood = ast.literal_eval(neighborhood) - self.use_exact = use_exact == "True" + self.use_exact = use_exact == "False" self.num_previous_segments = 0 - def _process(self, data, chunk_num_voxels, chunk_corner): - data = data / 255.0 - n_channels = data.shape[0] - self.neighborhood = self.neighborhood[:n_channels] - # raise Exception(data.max(), data.min(), self.neighborhood) + import numpy as np + from scipy.ndimage import measurements - segmentation = mws.agglom( - data.astype(np.float64) - self.bias, - self.neighborhood, - ) + def filter_fragments( + self, affs_data: np.ndarray, fragments_data: np.ndarray, filter_val: float + ) -> None: + """Allows filtering of MWS fragments based on mean value of affinities & fragments. Will filter and update the fragment array in-place. - # filter fragments - average_affs = np.mean(data, axis=0) + Args: + affs_data (``np.ndarray``): + An array containing affinity data. 
+ + fragments_data (``np.ndarray``): + An array containing fragment data. + + filter_val (``float``): + Threshold to filter if the average value falls below. + """ - filtered_fragments = [] + average_affs: float = np.mean(affs_data.data, axis=0) - fragment_ids = fastremap.unique(segmentation[segmentation > 0]) + filtered_fragments: list = [] + + fragment_ids: np.ndarray = np.unique(fragments_data) for fragment, mean in zip( - fragment_ids, measurements.mean(average_affs, segmentation, fragment_ids) + fragment_ids, measurements.mean(average_affs, fragments_data, fragment_ids) ): - if mean >= self.bias: + if mean < filter_val: filtered_fragments.append(fragment) - fastremap.mask_except(segmentation, filtered_fragments, in_place=True) + filtered_fragments: np.ndarray = np.array( + filtered_fragments, dtype=fragments_data.dtype + ) + # replace: np.ndarray = np.zeros_like(filtered_fragments) + fastremap.mask(fragments_data, filtered_fragments, in_place=True) + + def _process(self, data, chunk_num_voxels, chunk_corner): + data[data < self.filter_val] = 0 + if data.dtype == np.uint8: + logger.info("Assuming affinities are in [0,255]") + max_affinity_value: float = 255.0 + data = data.astype(np.float64) + else: + data = data.astype(np.float64) + max_affinity_value: float = 1.0 + + data /= max_affinity_value + + if data.max() < 1e-4: + segmentation = np.zeros( + data.shape, dtype=np.uint64 if self.use_exact else np.uint16 + ) + return np.expand_dims(segmentation, axis=0) + + channels = [ + channel for channel, ntp in enumerate(self.neighborhood) if ntp is not None + ] + neighborhood = [self.neighborhood[channel] for channel in channels] + + data = data[channels] + random_noise: float = np.random.randn(*data.shape) * 0.0001 + smoothed_affs: np.ndarray = ( + gaussian_filter(data, sigma=(0, *(np.amax(neighborhood, axis=0) / 3))) - 0.5 + ) * 0.001 + shift: np.ndarray = np.array( + [ + ( + self.adjacent_edge_bias + if max(offset) <= 1 + else np.linalg.norm(offset) * 
self.lr_bias_ratio + ) + for offset in neighborhood + ] + ).reshape((-1, *((1,) * (len(data.shape) - 1)))) + + # raise Exception(data.max(), data.min(), self.neighborhood) + + # segmentation = mws.agglom( + # data.astype(np.float64) - self.bias, + # self.neighborhood, + # ) + + # filter fragments + segmentation = mws.agglom( + data + shift + random_noise + smoothed_affs, + offsets=neighborhood, + ) + if self.filter_val > 0.0: + self.filter_fragments(data, segmentation, self.filter_val) + + # fragment_ids = fastremap.unique(segmentation[segmentation > 0]) + # fastremap.mask_except(segmentation, filtered_fragments, in_place=True) fastremap.renumber(segmentation, in_place=True) unique_increment = chunk_num_voxels * pymorton.interleave(*chunk_corner) if not self.use_exact: From 7c83d09533e248ea48552cf8e545066e0aba711a Mon Sep 17 00:00:00 2001 From: David Ackerman Date: Tue, 5 Aug 2025 13:27:35 -0400 Subject: [PATCH 3/4] remove comments --- cellmap_flow/utils/data.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/cellmap_flow/utils/data.py b/cellmap_flow/utils/data.py index 1313ec3..59c2b05 100644 --- a/cellmap_flow/utils/data.py +++ b/cellmap_flow/utils/data.py @@ -637,11 +637,3 @@ def _get_config(self) -> Config: config.model.to(device) config.model.eval() return config - - -# %% -# config, task = DaCapoModelConfig( -# run_name="finetuned_07-31-24_3d_lsdaffs_jrc_mus-liver-zon-1_07-31-24_nuclear_pores_pseudorandom_training_centers_unet_default_v2_no_dataset_predictor_node_lr_5E-5__0", -# iteration=55000, -# name="Example DaCapo Model", -# )._get_config() From 59fc08dbe909818393963f2b6785f095975e55fd Mon Sep 17 00:00:00 2001 From: davidackerman Date: Tue, 5 Aug 2025 13:41:20 -0400 Subject: [PATCH 4/4] Update cellmap_flow/post/postprocessors.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- cellmap_flow/post/postprocessors.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cellmap_flow/post/postprocessors.py 
b/cellmap_flow/post/postprocessors.py index ccb377e..47bcafc 100644 --- a/cellmap_flow/post/postprocessors.py +++ b/cellmap_flow/post/postprocessors.py @@ -8,7 +8,6 @@ from scipy.ndimage import label import mwatershed as mws from scipy.ndimage import measurements, gaussian_filter -from scipy.ndimage import measurements import fastremap from funlib.math import cantor_number import fastmorph