From d6a932dbed5a85a98d3855758de60f771a8ab839 Mon Sep 17 00:00:00 2001 From: Michael Innerberger Date: Tue, 10 Feb 2026 11:45:57 -0500 Subject: [PATCH 01/13] Add first version of creep-correction translated from Ken's py script --- .../multisem/CreepCorrectionClient.java | 465 ++++++++++++++++++ .../multisem/CreepCorrectionSparkClient.java | 162 ++++++ 2 files changed, 627 insertions(+) create mode 100644 render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java create mode 100644 render-ws-spark-client/src/main/java/org/janelia/render/client/spark/multisem/CreepCorrectionSparkClient.java diff --git a/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java b/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java new file mode 100644 index 000000000..c9a7eaa2b --- /dev/null +++ b/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java @@ -0,0 +1,465 @@ +package org.janelia.render.client.multisem; + +import static org.janelia.alignment.spec.ResolvedTileSpecCollection.TransformApplicationMethod.INSERT_BEFORE_LAST; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +import org.janelia.alignment.match.CanvasMatchResult; +import org.janelia.alignment.match.CanvasMatches; +import org.janelia.alignment.multisem.MultiSemUtilities; +import org.janelia.alignment.spec.LeafTransformSpec; +import org.janelia.alignment.spec.LayoutData; +import org.janelia.alignment.spec.ResolvedTileSpecCollection; +import org.janelia.alignment.spec.TileSpec; +import org.janelia.alignment.spec.TransformSpec; +import org.janelia.render.client.RenderDataClient; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import mpicbg.models.AffineModel2D; +import mpicbg.models.NotEnoughDataPointsException; +import mpicbg.models.PointMatch; + +/** + * Estimates and applies piezo creep correction for multi-SEM sFOV tiles within each mFOV. + * + *

The correction is estimated by fitting pairwise affine transforms between geometrically + * adjacent sFOV pairs and extracting the y-stretch factor. A double-exponential correction + * (using {@link org.janelia.alignment.transform.SEMDistortionTransformA}) is then derived + * and inserted into each tile's transform list.

+ * + * @author Michael Innerberger + */ +public class CreepCorrectionClient { + + // physical constants for the double-exponential model + private static final double TAU_0 = 0.42; // seconds + private static final double TAU_1 = 4.0; // seconds + private static final double PIXEL_DWELL = 800e-9; // seconds + private static final double LINE_WIDTH = 2000; // pixels + private static final double L0 = TAU_0 / (PIXEL_DWELL * LINE_WIDTH); // 262500 lines + private static final double L1 = TAU_1 / (PIXEL_DWELL * LINE_WIDTH); // 2500000 lines + + // line numbers used for stretch-to-amplitude conversion + private static final int TOP_LINE = 50; + private static final int BOTTOM_LINE = 1700; + + // validation thresholds + private static final double MIN_STRETCH = 0.9; + private static final double MAX_STRETCH = 1.1; + private static final int MIN_VALID_STRETCHES = 10; + private static final double MAX_STRETCH_STDDEV = 0.02; + + // RANSAC parameters for pairwise affine estimation + private static final int RANSAC_ITERATIONS = 1000; + private static final double RANSAC_MAX_EPSILON = 10.0; + private static final double RANSAC_MIN_INLIER_RATIO = 0.1; + private static final int RANSAC_MIN_NUM_INLIERS = 7; + private static final double RANSAC_MAX_TRUST = 3.0; + + // minimum offset for geometric neighbor detection (in stage coordinate units) + private static final double MIN_NEIGHBOR_OFFSET = 100.0; + + /** + * Processes all mFOVs for a given z-layer: loads tiles and matches, estimates creep correction, + * applies it where valid, and saves the results to the target stack. + * + * @return number of tiles processed + */ + public int processZLayer(final double z, + final RenderDataClient renderDataClient, + final RenderDataClient matchDataClient, + final String stack, + final String targetStack) + throws IOException { + + LOG.info("processZLayer: entry, z={}", z); + + final ResolvedTileSpecCollection resolvedTiles = renderDataClient.getResolvedTiles(stack, z); + + // group tiles by mFOV + final Map> mfovToTiles = new HashMap<>(); + for (final TileSpec tileSpec : resolvedTiles.getTileSpecs()) { + final String mfov = MultiSemUtilities.getMagcMfovForTileId(tileSpec.getTileId()); + mfovToTiles.computeIfAbsent(mfov, k -> new ArrayList<>()).add(tileSpec); + } + + // load all within-group matches and index by tile pair + final String groupId = String.valueOf(z); + final List allMatches = matchDataClient.getMatchesWithinGroup(groupId, false); + final Map pairKeyToMatches = new HashMap<>(); + for (final CanvasMatches match : allMatches) { + final String pMFOV = MultiSemUtilities.getMagcMfovForTileId(match.getpId()); + final String qMFOV = MultiSemUtilities.getMagcMfovForTileId(match.getqId()); + if (pMFOV.equals(qMFOV)) { + pairKeyToMatches.put(pairKey(match.getpId(), match.getqId()), match); + } + } + + // process each mFOV independently + int correctedMFOVCount = 0; + int skippedMFOVCount = 0; + for (final Map.Entry> entry : mfovToTiles.entrySet()) { + final String mfov = entry.getKey(); + final List mfovTiles = entry.getValue(); + + final boolean corrected = processMFOV(mfov, mfovTiles, pairKeyToMatches, resolvedTiles); + if (corrected) { + correctedMFOVCount++; + } else { + skippedMFOVCount++; + } + } + + // save all tiles (both corrected and uncorrected) + renderDataClient.saveResolvedTiles(resolvedTiles, targetStack, z); + + LOG.info("processZLayer: exit, z={}, correctedMFOVs={}, skippedMFOVs={}, totalTiles={}", + z, correctedMFOVCount, skippedMFOVCount, resolvedTiles.getTileCount()); + + return resolvedTiles.getTileCount(); + } + + /** + * Processes a single mFOV: finds neighbor pairs, estimates stretch, validates, and applies correction. + * + * @return true if correction was applied, false if skipped + */ + boolean processMFOV(final String mfov, + final List mfovTiles, + final Map pairKeyToMatches, + final ResolvedTileSpecCollection resolvedTiles) { + + // find geometric neighbor pairs + final List neighborPairs = findGeometricNeighborPairs(mfovTiles); + + if (neighborPairs.isEmpty()) { + LOG.warn("processMFOV: no geometric neighbor pairs found for mFOV {}, skipping", mfov); + return false; + } + + // estimate y-stretch for each pair that has matches + final List stretches = new ArrayList<>(); + int pairsWithoutMatches = 0; + for (final TilePair pair : neighborPairs) { + final String key = pairKey(pair.pTileId, pair.qTileId); + final CanvasMatches matches = pairKeyToMatches.get(key); + + if (matches == null) { + pairsWithoutMatches++; + continue; + } + + final double yStretch = estimateYStretchForPair( + CanvasMatchResult.convertMatchesToPointMatchList(matches.getMatches())); + stretches.add(yStretch); + } + + if (pairsWithoutMatches > 0) { + LOG.info("processMFOV: {} of {} neighbor pairs in mFOV {} had no matches", + pairsWithoutMatches, neighborPairs.size(), mfov); + } + + // validate results + final ValidationResult validation = validateResults(stretches, mfov); + + if (!validation.isValid) { + LOG.warn("processMFOV: skipping mFOV {} - {}", mfov, validation.diagnosticMessage); + return false; + } + + // build and apply correction transform + final TransformSpec correctionSpec = buildCorrectionTransformSpec(validation.amplitude); + final Set mfovTileIds = mfovTiles.stream() + .map(TileSpec::getTileId) + .collect(Collectors.toSet()); + applyCorrectionToMFOV(resolvedTiles, mfovTileIds, correctionSpec); + + LOG.info("processMFOV: applied creep correction to mFOV {} with amplitude={}, medianStretch={}, stddev={}", + mfov, validation.amplitude, validation.medianStretch, validation.stddev); + + return true; + } + + /** + * Finds geometric neighbor pairs (lower-right and lower-left) using stage positions, + * replicating the Python prototype's neighbor-finding logic. + */ + List findGeometricNeighborPairs(final List mfovTiles) { + final List pairs = new ArrayList<>(); + + for (final TileSpec targetTile : mfovTiles) { + final StageCoordinates target = getStageCoordinates(targetTile); + if (target == null) continue; + + // find closest lower-right neighbor + findClosestNeighbor(targetTile.getTileId(), target, mfovTiles, true) + .ifPresent(pairs::add); + + // find closest lower-left neighbor + findClosestNeighbor(targetTile.getTileId(), target, mfovTiles, false) + .ifPresent(pairs::add); + } + + return pairs; + } + + private static StageCoordinates getStageCoordinates(final TileSpec targetTile) { + final LayoutData targetLayout = targetTile.getLayout(); + if (targetLayout == null || targetLayout.getStageX() == null || targetLayout.getStageY() == null) { + return null; + } + final double targetX = targetLayout.getStageX(); + final double targetY = targetLayout.getStageY(); + return new StageCoordinates(targetX, targetY); + } + + private Optional findClosestNeighbor(final String targetTileId, + final StageCoordinates target, + final List candidates, + final boolean lowerRight) { + double minDist = Double.MAX_VALUE; + String closestId = null; + + for (final TileSpec candidate : candidates) { + if (candidate.getTileId().equals(targetTileId)) { + continue; + } + final LayoutData layout = candidate.getLayout(); + final StageCoordinates stage = getStageCoordinates(candidate); + + if (layout == null || stage == null) { + continue; + } + + final boolean xCondition = lowerRight + ? (stage.x > target.x + MIN_NEIGHBOR_OFFSET) + : (stage.x < target.y - MIN_NEIGHBOR_OFFSET); + + if (xCondition && (stage.y > target.y + MIN_NEIGHBOR_OFFSET)) { + final double dist = Math.sqrt(Math.pow(stage.x - target.x, 2) + Math.pow(stage.y - target.y, 2)); + if (dist < minDist) { + minDist = dist; + closestId = candidate.getTileId(); + } + } + } + + if (closestId != null) { + // normalize pair ordering so pairKey lookups work regardless of match storage order + final String pId = targetTileId.compareTo(closestId) < 0 ? targetTileId : closestId; + final String qId = targetTileId.compareTo(closestId) < 0 ? closestId : targetTileId; + return java.util.Optional.of(new TilePair(pId, qId)); + } + return java.util.Optional.empty(); + } + + /** + * Estimates the y-stretch factor from a set of point matches by fitting an affine model with RANSAC. + * + * @return the y-stretch (m11 element of the affine), or {@link Double#NaN} on failure + */ + double estimateYStretchForPair(final List candidates) { + if (candidates.size() < RANSAC_MIN_NUM_INLIERS) { + return Double.NaN; + } + + final AffineModel2D model = new AffineModel2D(); + final List inliers = new ArrayList<>(); + + try { + model.filterRansac(candidates, + inliers, + RANSAC_ITERATIONS, + RANSAC_MAX_EPSILON, + RANSAC_MIN_INLIER_RATIO, + RANSAC_MIN_NUM_INLIERS, + RANSAC_MAX_TRUST); + } catch (final NotEnoughDataPointsException e) { + return Double.NaN; + } + + if (inliers.isEmpty()) { + return Double.NaN; + } + + final double[] affineData = new double[6]; + model.toArray(affineData); + // affineData layout: [m00, m10, m01, m11, m02, m12] + return affineData[3]; // m11 = y-stretch + } + + /** + * Computes the correction amplitude from the median y-stretch, using the double-exponential derivative model. + * Faithfully translates the Python prototype's {@code calculate_a_helper_func}. + */ + static double computeCorrectionAmplitude(final double medianStretch) { + final double b = (-1.0 / L0) * Math.exp(-(double) TOP_LINE / L0) + + (-1.0 / L1) * Math.exp(-(double) TOP_LINE / L1); + final double c = (-1.0 / L0) * Math.exp(-(double) BOTTOM_LINE / L0) + + (-1.0 / L1) * Math.exp(-(double) BOTTOM_LINE / L1); + return (medianStretch - 1.0) / (b - c * medianStretch); + } + + /** + * Centralized validation of all stretch estimates for a single mFOV. + * On failure, returns an invalid result; the mFOV should be skipped (uploaded without correction). + */ + ValidationResult validateResults(final List stretches, final String mfov) { + final int totalPairs = stretches.size(); + + if (totalPairs == 0) { + return ValidationResult.invalid("no stretch estimates available"); + } + + // filter out NaN and out-of-range values + final List validStretches = new ArrayList<>(); + int nanCount = 0; + int outOfRangeCount = 0; + for (final double s : stretches) { + if (Double.isNaN(s)) { + nanCount++; + } else if (s < MIN_STRETCH || s > MAX_STRETCH) { + outOfRangeCount++; + } else { + validStretches.add(s); + } + } + + if (nanCount > totalPairs / 2) { + LOG.warn("validateResults: mFOV {} has {} NaN stretches out of {} total", + mfov, nanCount, totalPairs); + } + if (outOfRangeCount > 0) { + LOG.info("validateResults: mFOV {} had {} stretches outside [{}, {}]", + mfov, outOfRangeCount, MIN_STRETCH, MAX_STRETCH); + } + + if (validStretches.size() < MIN_VALID_STRETCHES) { + return ValidationResult.invalid( + "only " + validStretches.size() + " valid stretches (need " + MIN_VALID_STRETCHES + + "), nanCount=" + nanCount + ", outOfRange=" + outOfRangeCount); + } + + // compute median + validStretches.sort(Double::compareTo); + final double medianStretch; + final int n = validStretches.size(); + if (n % 2 == 0) { + medianStretch = (validStretches.get(n / 2 - 1) + validStretches.get(n / 2)) / 2.0; + } else { + medianStretch = validStretches.get(n / 2); + } + + // compute standard deviation + final double mean = validStretches.stream().mapToDouble(Double::doubleValue).average().orElse(0.0); + final double variance = validStretches.stream() + .mapToDouble(s -> Math.pow(s - mean, 2)) + .sum() / validStretches.size(); + final double stddev = Math.sqrt(variance); + + if (stddev > MAX_STRETCH_STDDEV) { + return ValidationResult.invalid( + "stretch stddev " + stddev + " exceeds threshold " + MAX_STRETCH_STDDEV + + " (median=" + medianStretch + ", validCount=" + validStretches.size() + ")"); + } + + // compute amplitude + final double amplitude = computeCorrectionAmplitude(medianStretch); + if (!Double.isFinite(amplitude)) { + return ValidationResult.invalid( + "computed amplitude is not finite (median=" + medianStretch + ")"); + } + + LOG.info("validateResults: mFOV {} - totalPairs={}, valid={}, nan={}, outOfRange={}, " + + "median={}, stddev={}, amplitude={}", + mfov, totalPairs, validStretches.size(), nanCount, outOfRangeCount, + medianStretch, stddev, amplitude); + + return new ValidationResult(true, medianStretch, stddev, amplitude, "OK"); + } + + /** + * Builds a {@link LeafTransformSpec} for the creep correction using {@code SEMDistortionTransformA}. + * Formula: {@code y += amplitude * exp(-y/L0) + amplitude * exp(-y/L1)} + */ + static TransformSpec buildCorrectionTransformSpec(final double amplitude) { + final String dataString = amplitude + "," + L0 + "," + amplitude + "," + L1 + ",1"; + return new LeafTransformSpec("org.janelia.alignment.transform.SEMDistortionTransformA", dataString); + } + + /** + * Applies the correction transform to all tiles of an mFOV by inserting it before the last + * (alignment) transform, placing it after the existing scan correction. + */ + private void applyCorrectionToMFOV(final ResolvedTileSpecCollection tiles, + final Set mfovTileIds, + final TransformSpec correctionSpec) { + for (final String tileId : mfovTileIds) { + tiles.addTransformSpecToTile(tileId, correctionSpec, INSERT_BEFORE_LAST); + } + } + + /** + * Creates a canonical pair key for match lookups. CanvasMatches normalizes p < q, + * so we ensure the same ordering. + */ + private static String pairKey(final String id1, final String id2) { + return id1.compareTo(id2) < 0 ? id1 + "::" + id2 : id2 + "::" + id1; + } + + /** A pair of tile IDs representing geometrically adjacent sFOVs. */ + static class TilePair { + final String pTileId; + final String qTileId; + + TilePair(final String pTileId, final String qTileId) { + this.pTileId = pTileId; + this.qTileId = qTileId; + } + } + + /** Result of validating stretch estimates for an mFOV. */ + static class ValidationResult { + final boolean isValid; + final double medianStretch; + final double stddev; + final double amplitude; + final String diagnosticMessage; + + ValidationResult(final boolean isValid, + final double medianStretch, + final double stddev, + final double amplitude, + final String diagnosticMessage) { + this.isValid = isValid; + this.medianStretch = medianStretch; + this.stddev = stddev; + this.amplitude = amplitude; + this.diagnosticMessage = diagnosticMessage; + } + + static ValidationResult invalid(final String diagnosticMessage) { + return new ValidationResult(false, Double.NaN, Double.NaN, Double.NaN, diagnosticMessage); + } + } + + static class StageCoordinates { + final double x; + final double y; + + StageCoordinates(final double x, final double y) { + this.x = x; + this.y = y; + } + } + + private static final Logger LOG = LoggerFactory.getLogger(CreepCorrectionClient.class); +} diff --git a/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/multisem/CreepCorrectionSparkClient.java b/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/multisem/CreepCorrectionSparkClient.java new file mode 100644 index 000000000..bad722172 --- /dev/null +++ b/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/multisem/CreepCorrectionSparkClient.java @@ -0,0 +1,162 @@ +package org.janelia.render.client.spark.multisem; + +import com.beust.jcommander.Parameter; +import com.beust.jcommander.ParametersDelegate; + +import java.io.IOException; +import java.io.Serializable; +import java.util.List; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.janelia.alignment.spec.stack.StackMetaData; +import org.janelia.render.client.ClientRunner; +import org.janelia.render.client.RenderDataClient; +import org.janelia.render.client.multisem.CreepCorrectionClient; +import org.janelia.render.client.parameter.CommandLineParameters; +import org.janelia.render.client.parameter.RenderWebServiceParameters; +import org.janelia.render.client.parameter.ZRangeParameters; +import org.janelia.render.client.spark.LogUtilities; +import org.jetbrains.annotations.NotNull; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Spark client for applying piezo creep correction to multi-SEM tiles. + * Each z-layer is processed independently on a Spark executor. + * + *

Within each z-layer, all mFOVs are processed sequentially: for each mFOV, the y-stretch + * is estimated from pairwise affine fits of geometrically adjacent sFOV pairs, and a + * double-exponential correction is applied if validation passes. mFOVs that fail validation + * are skipped (uploaded without correction).

+ * + * @see CreepCorrectionClient + * + * @author Michael Innerberger + */ +public class CreepCorrectionSparkClient implements Serializable { + + public static class Parameters extends CommandLineParameters { + + @ParametersDelegate + public RenderWebServiceParameters renderWeb = new RenderWebServiceParameters(); + + @ParametersDelegate + public ZRangeParameters layerRange = new ZRangeParameters(); + + @Parameter( + names = "--stack", + description = "Name of source stack", + required = true) + public String stack; + + @Parameter( + names = "--targetStack", + description = "Name of target stack for corrected tiles", + required = true) + public String targetStack; + + @Parameter( + names = "--matchOwner", + description = "Owner of match collection (default is same as render owner)") + public String matchOwner; + + @Parameter( + names = "--matchCollection", + description = "Name of match collection containing within-layer montage matches", + required = true) + public String matchCollection; + + String getMatchOwner() { + return matchOwner != null ? matchOwner : renderWeb.owner; + } + } + + public static void main(final String[] args) { + final ClientRunner clientRunner = new ClientRunner(args) { + @Override + public void runClient(final String[] args) throws Exception { + final Parameters parameters = new Parameters(); + parameters.parse(args); + + LOG.info("runClient: entry, parameters={}", parameters); + + final CreepCorrectionSparkClient client = new CreepCorrectionSparkClient(parameters); + client.run(); + } + }; + clientRunner.run(); + } + + private final Parameters parameters; + + public CreepCorrectionSparkClient(final Parameters parameters) { + this.parameters = parameters; + } + + public void run() throws IOException { + + final SparkConf conf = new SparkConf().setAppName("CreepCorrectionSparkClient"); + + try (final JavaSparkContext sparkContext = new JavaSparkContext(conf)) { + + final String sparkAppId = sparkContext.getConf().getAppId(); + final String executorsJson = LogUtilities.getExecutorsApiJson(sparkAppId); + LOG.info("run: appId is {}, executors data is {}", sparkAppId, executorsJson); + + final RenderDataClient sourceDataClient = parameters.renderWeb.getDataClient(); + + final List zValues = sourceDataClient.getStackZValues(parameters.stack, + parameters.layerRange.minZ, + parameters.layerRange.maxZ); + + if (zValues.isEmpty()) { + throw new IllegalArgumentException("source stack does not contain any matching z values"); + } + + // set up target stack on the driver + final StackMetaData sourceStackMetaData = sourceDataClient.getStackMetaData(parameters.stack); + sourceDataClient.setupDerivedStack(sourceStackMetaData, parameters.targetStack); + + LOG.info("run: distributing {} z values for processing", zValues.size()); + + final JavaRDD rddZValues = sparkContext.parallelize(zValues); + + final JavaRDD rddTileCounts = rddZValues.map(this::processSingleLayer); + final List tileCountList = rddTileCounts.collect(); + + long total = 0; + for (final Integer tileCount : tileCountList) { + total += tileCount; + } + + LOG.info("run: processed {} tiles across {} z-layers", total, zValues.size()); + + // complete target stack on the driver + sourceDataClient.setStackState(parameters.targetStack, StackMetaData.StackState.COMPLETE); + } + + LOG.info("run: exit"); + } + + private Integer processSingleLayer(final Double z) throws IOException { + LogUtilities.setupExecutorLog4j("z " + z); + + final RenderDataClient executorRenderClient = parameters.renderWeb.getDataClient(); + final RenderDataClient executorMatchClient = new RenderDataClient( + parameters.renderWeb.baseDataUrl, + parameters.getMatchOwner(), + parameters.matchCollection); + + final CreepCorrectionClient correctionClient = new CreepCorrectionClient(); + return correctionClient.processZLayer(z, + executorRenderClient, + executorMatchClient, + parameters.stack, + parameters.targetStack); + } + + private static final Logger LOG = LoggerFactory.getLogger(CreepCorrectionSparkClient.class); +} From 132fd3e9cac6f4e816db4d7524388a74620f653d Mon Sep 17 00:00:00 2001 From: Michael Innerberger Date: Tue, 10 Feb 2026 14:37:26 -0500 Subject: [PATCH 02/13] Fix bug in neighbor identification logic --- .../janelia/render/client/multisem/CreepCorrectionClient.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java b/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java index c9a7eaa2b..0c113073d 100644 --- a/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java +++ b/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java @@ -241,7 +241,7 @@ private Optional findClosestNeighbor(final String targetTileId, final boolean xCondition = lowerRight ? (stage.x > target.x + MIN_NEIGHBOR_OFFSET) - : (stage.x < target.y - MIN_NEIGHBOR_OFFSET); + : (stage.x < target.x - MIN_NEIGHBOR_OFFSET); if (xCondition && (stage.y > target.y + MIN_NEIGHBOR_OFFSET)) { final double dist = Math.sqrt(Math.pow(stage.x - target.x, 2) + Math.pow(stage.y - target.y, 2)); From ad50b6c7517b0a44d6ca35f7efa892c16cecea45 Mon Sep 17 00:00:00 2001 From: Michael Innerberger Date: Tue, 17 Feb 2026 16:22:33 -0500 Subject: [PATCH 03/13] Add a new transformation that is true to the original python script --- .../StageCreepCorrectionTransform.java | 78 +++++++++++++++++++ .../multisem/CreepCorrectionClient.java | 14 ++-- 2 files changed, 87 insertions(+), 5 deletions(-) create mode 100644 render-app/src/main/java/org/janelia/alignment/transform/StageCreepCorrectionTransform.java diff --git a/render-app/src/main/java/org/janelia/alignment/transform/StageCreepCorrectionTransform.java b/render-app/src/main/java/org/janelia/alignment/transform/StageCreepCorrectionTransform.java new file mode 100644 index 000000000..003d90fc0 --- /dev/null +++ b/render-app/src/main/java/org/janelia/alignment/transform/StageCreepCorrectionTransform.java @@ -0,0 +1,78 @@ +package org.janelia.alignment.transform; + +import mpicbg.trakem2.transform.CoordinateTransform; + +/** + * Forward transform for piezo stage creep correction in multi-SEM imaging. + * + *

The backward map (used by the Python prototype via cv2.remap) is:

+ *
y_source = y_world + a * exp(-y_world / b) + c * exp(-y_world / d)
+ * + *

This class implements the exact forward (source → world) transform by numerically + * inverting the backward map using Newton's method: given {@code y_s}, find {@code y_w} such that + * {@code y_w + a * exp(-y_w / b) + c * exp(-y_w / d) = y_s}.

+ * + *

Coefficients: a, b, c, d (same as {@link SEMDistortionTransformA}). + * Data string format: {@code "a,b,c,d,dimension"}.

+ */ +public class StageCreepCorrectionTransform + extends MultiParameterSingleDimensionTransform { + + private static final int MAX_ITERATIONS = 10; + private static final double TOLERANCE = 1e-9; + + public StageCreepCorrectionTransform() { + this(0, 0, 0, 0, 0); + } + + public StageCreepCorrectionTransform(final double a, + final double b, + final double c, + final double d, + final int dimension) { + super(new double[] {a, b, c, d}, dimension); + } + + @Override + public int getNumberOfCoefficients() { + return 4; + } + + /** + * Forward transform (source → world): given y_s, computes y_w by solving + * {@code y_w + a * exp(-y_w / b) + c * exp(-y_w / d) = y_s} using Newton's method. + */ + @Override + public void applyInPlace(final double[] location) { + final double a = coefficients[0]; + final double b = coefficients[1]; + final double c = coefficients[2]; + final double d = coefficients[3]; + final double y_s = location[dimension]; + + // solve g(y_w) = y_w + a*exp(-y_w/b) + c*exp(-y_w/d) - y_s = 0 + // g'(y_w) = 1 - (a/b)*exp(-y_w/b) - (c/d)*exp(-y_w/d) + double y_w = y_s; // initial guess + for (int i = 0; i < MAX_ITERATIONS; i++) { + final double expB = Math.exp(-y_w / b); + final double expD = Math.exp(-y_w / d); + final double g = y_w + a * expB + c * expD - y_s; + if (Math.abs(g) < TOLERANCE) { + break; + } + final double gPrime = 1.0 - (a / b) * expB - (c / d) * expD; + y_w -= g / gPrime; + } + + location[dimension] = y_w; + } + + @Override + public CoordinateTransform copy() { + return new StageCreepCorrectionTransform(coefficients[0], + coefficients[1], + coefficients[2], + coefficients[3], + dimension); + } +} diff --git a/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java b/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java index 0c113073d..696fda0e5 100644 --- a/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java +++ b/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java @@ -297,8 +297,10 @@ private Optional findClosestNeighbor(final String targetTileId, } /** - * Computes the correction amplitude from the median y-stretch, using the double-exponential derivative model. - * Faithfully translates the Python prototype's {@code calculate_a_helper_func}. + * Computes the creep correction amplitude from the median y-stretch. + * Returns a negative value (same sign as the Python prototype's {@code calculate_a_helper_func}), + * which is used directly as the coefficient in {@link StageCreepCorrectionTransform}'s backward map: + * {@code y_source = y_world + a * exp(-y_world / L)}. */ static double computeCorrectionAmplitude(final double medianStretch) { final double b = (-1.0 / L0) * Math.exp(-(double) TOP_LINE / L0) @@ -387,12 +389,14 @@ ValidationResult validateResults(final List stretches, final String mfov } /** - * Builds a {@link LeafTransformSpec} for the creep correction using {@code SEMDistortionTransformA}. - * Formula: {@code y += amplitude * exp(-y/L0) + amplitude * exp(-y/L1)} + * Builds a {@link LeafTransformSpec} for the creep correction using {@code StageCreepCorrectionTransform}. + * The amplitude is negative (same as the Python prototype), defining the backward map + * {@code y_s = y_w + a*exp(-y_w/L0) + a*exp(-y_w/L1)} where negative {@code a} reads from above, + * compressing the top of the image to correct for creep stretch. */ static TransformSpec buildCorrectionTransformSpec(final double amplitude) { final String dataString = amplitude + "," + L0 + "," + amplitude + "," + L1 + ",1"; - return new LeafTransformSpec("org.janelia.alignment.transform.SEMDistortionTransformA", dataString); + return new LeafTransformSpec("org.janelia.alignment.transform.StageCreepCorrectionTransform", dataString); } /** From baeff857a2008476bbe75a6c79d6fe83203e208a Mon Sep 17 00:00:00 2001 From: Michael Innerberger Date: Fri, 20 Feb 2026 16:29:08 -0500 Subject: [PATCH 04/13] Require creep correction transform to be part of matching --- .../render/client/multisem/CreepCorrectionClient.java | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java b/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java index 696fda0e5..60deca6cc 100644 --- a/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java +++ b/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java @@ -19,6 +19,7 @@ import org.janelia.alignment.spec.ResolvedTileSpecCollection; import org.janelia.alignment.spec.TileSpec; import org.janelia.alignment.spec.TransformSpec; +import org.janelia.alignment.spec.TransformSpecMetaData; import org.janelia.render.client.RenderDataClient; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -396,7 +397,10 @@ ValidationResult validateResults(final List stretches, final String mfov */ static TransformSpec buildCorrectionTransformSpec(final double amplitude) { final String dataString = amplitude + "," + L0 + "," + amplitude + "," + L1 + ",1"; - return new LeafTransformSpec("org.janelia.alignment.transform.StageCreepCorrectionTransform", dataString); + final LeafTransformSpec spec = new LeafTransformSpec( + "org.janelia.alignment.transform.StageCreepCorrectionTransform", dataString); + spec.addLabel(TransformSpecMetaData.INCLUDE_LABEL); + return spec; } /** From f489d28a15684ec428a1e7c2e9defd3764072350 Mon Sep 17 00:00:00 2001 From: Michael Innerberger Date: Tue, 24 Mar 2026 10:29:55 -0400 Subject: [PATCH 05/13] Revert "Require creep correction transform to be part of matching" This reverts commit baeff857a2008476bbe75a6c79d6fe83203e208a. Even with this flag, the transformation doesn't seem to be recognized during matching; try another approach --- .../render/client/multisem/CreepCorrectionClient.java | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java b/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java index 60deca6cc..696fda0e5 100644 --- a/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java +++ b/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java @@ -19,7 +19,6 @@ import org.janelia.alignment.spec.ResolvedTileSpecCollection; import org.janelia.alignment.spec.TileSpec; import org.janelia.alignment.spec.TransformSpec; -import org.janelia.alignment.spec.TransformSpecMetaData; import org.janelia.render.client.RenderDataClient; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -397,10 +396,7 @@ ValidationResult validateResults(final List stretches, final String mfov */ static TransformSpec buildCorrectionTransformSpec(final double amplitude) { final String dataString = amplitude + "," + L0 + "," + amplitude + "," + L1 + ",1"; - final LeafTransformSpec spec = new LeafTransformSpec( - "org.janelia.alignment.transform.StageCreepCorrectionTransform", dataString); - spec.addLabel(TransformSpecMetaData.INCLUDE_LABEL); - return spec; + return new LeafTransformSpec("org.janelia.alignment.transform.StageCreepCorrectionTransform", dataString); } /** From e8d67c708eb7c5562bfb401a02d0b13e8c6b0880 Mon Sep 17 00:00:00 2001 From: Michael Innerberger Date: Tue, 24 Mar 2026 14:03:24 -0400 Subject: [PATCH 06/13] Transform matches alongside creep correction --- .../multisem/CreepCorrectionClient.java | 209 ++++++++++++++++-- .../multisem/CreepCorrectionSparkClient.java | 100 ++++++++- 2 files changed, 282 insertions(+), 27 deletions(-) diff --git a/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java b/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java index 696fda0e5..cc814ab94 100644 --- a/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java +++ b/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java @@ -3,8 +3,11 @@ import static org.janelia.alignment.spec.ResolvedTileSpecCollection.TransformApplicationMethod.INSERT_BEFORE_LAST; import java.io.IOException; +import java.io.Serializable; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; @@ -13,6 +16,7 @@ import org.janelia.alignment.match.CanvasMatchResult; import org.janelia.alignment.match.CanvasMatches; +import org.janelia.alignment.match.Matches; import org.janelia.alignment.multisem.MultiSemUtilities; import org.janelia.alignment.spec.LeafTransformSpec; import org.janelia.alignment.spec.LayoutData; @@ -24,6 +28,10 @@ import org.slf4j.LoggerFactory; import mpicbg.models.AffineModel2D; +import mpicbg.models.CoordinateTransform; +import mpicbg.models.InvertibleCoordinateTransform; +import mpicbg.models.InvertibleCoordinateTransformList; +import mpicbg.models.NoninvertibleModelException; import mpicbg.models.NotEnoughDataPointsException; import mpicbg.models.PointMatch; @@ -71,13 +79,13 @@ public class CreepCorrectionClient { * Processes all mFOVs for a given z-layer: loads tiles and matches, estimates creep correction, * applies it where valid, and saves the results to the target stack. * - * @return number of tiles processed + * @return result containing tile count and per-mFOV correction specs */ - public int processZLayer(final double z, - final RenderDataClient renderDataClient, - final RenderDataClient matchDataClient, - final String stack, - final String targetStack) + public ZLayerResult processZLayer(final double z, + final RenderDataClient renderDataClient, + final RenderDataClient matchDataClient, + final String stack, + final String targetStack) throws IOException { LOG.info("processZLayer: entry, z={}", z); @@ -104,14 +112,16 @@ public int processZLayer(final double z, } // process each mFOV independently + final Map mfovCorrections = new HashMap<>(); int correctedMFOVCount = 0; int skippedMFOVCount = 0; for (final Map.Entry> entry : mfovToTiles.entrySet()) { final String mfov = entry.getKey(); final List mfovTiles = entry.getValue(); - final boolean corrected = processMFOV(mfov, mfovTiles, pairKeyToMatches, resolvedTiles); - if (corrected) { + final TransformSpec correctionSpec = processMFOV(mfov, mfovTiles, pairKeyToMatches, resolvedTiles); + if (correctionSpec != null) { + mfovCorrections.put(mfov, correctionSpec); correctedMFOVCount++; } else { skippedMFOVCount++; @@ -124,25 +134,25 @@ public int processZLayer(final double z, LOG.info("processZLayer: exit, z={}, correctedMFOVs={}, skippedMFOVs={}, totalTiles={}", z, correctedMFOVCount, skippedMFOVCount, resolvedTiles.getTileCount()); - return resolvedTiles.getTileCount(); + return new ZLayerResult(resolvedTiles.getTileCount(), mfovCorrections); } /** * Processes a single mFOV: finds neighbor pairs, estimates stretch, validates, and applies correction. * - * @return true if correction was applied, false if skipped + * @return the correction transform spec if applied, or null if skipped */ - boolean processMFOV(final String mfov, - final List mfovTiles, - final Map pairKeyToMatches, - final ResolvedTileSpecCollection resolvedTiles) { + TransformSpec processMFOV(final String mfov, + final List mfovTiles, + final Map pairKeyToMatches, + final ResolvedTileSpecCollection resolvedTiles) { // find geometric neighbor pairs final List neighborPairs = findGeometricNeighborPairs(mfovTiles); if (neighborPairs.isEmpty()) { LOG.warn("processMFOV: no geometric neighbor pairs found for mFOV {}, skipping", mfov); - return false; + return null; } // estimate y-stretch for each pair that has matches @@ -172,7 +182,7 @@ boolean processMFOV(final String mfov, if (!validation.isValid) { LOG.warn("processMFOV: skipping mFOV {} - {}", mfov, validation.diagnosticMessage); - return false; + return null; } // build and apply correction transform @@ -185,7 +195,7 @@ boolean processMFOV(final String mfov, LOG.info("processMFOV: applied creep correction to mFOV {} with amplitude={}, medianStretch={}, stddev={}", mfov, validation.amplitude, validation.medianStretch, validation.stddev); - return true; + return correctionSpec; } /** @@ -419,6 +429,171 @@ private static String pairKey(final String id1, final String id2) { return id1.compareTo(id2) < 0 ? id1 + "::" + id2 : id2 + "::" + id1; } + /** + * Transforms all matches for a given group (z-layer) using the collected creep corrections. + * Handles both within-group and outside-group matches. + */ + public void transformMatchesForGroup(final String groupId, + final Map allResults, + final RenderDataClient renderDataClient, + final RenderDataClient sourceMatchClient, + final RenderDataClient targetMatchClient, + final String stack) + throws IOException { + + LOG.info("transformMatchesForGroup: entry, groupId={}", groupId); + + // get within-group and outside-group matches + final List withinMatches = sourceMatchClient.getMatchesWithinGroup(groupId, false); + final List outsideMatches = sourceMatchClient.getMatchesOutsideGroup(groupId, false); + + final List allMatches = new ArrayList<>(withinMatches); + allMatches.addAll(outsideMatches); + + if (allMatches.isEmpty()) { + LOG.info("transformMatchesForGroup: no matches found for groupId={}", groupId); + return; + } + + // collect all z-layers referenced by these matches + final Set neededZLayers = new HashSet<>(); + neededZLayers.add(groupId); + for (final CanvasMatches cm : outsideMatches) { + neededZLayers.add(cm.getpGroupId()); + neededZLayers.add(cm.getqGroupId()); + } + + // load tile specs for all referenced z-layers from source stack + final Map tileIdToSpec = new HashMap<>(); + for (final String zString : neededZLayers) { + final ResolvedTileSpecCollection resolved = + renderDataClient.getResolvedTiles(stack, Double.parseDouble(zString)); + for (final TileSpec ts : resolved.getTileSpecs()) { + tileIdToSpec.put(ts.getTileId(), ts); + } + } + + // transform and collect matches + final List transformedMatches = new ArrayList<>(); + for (final CanvasMatches cm : allMatches) { + transformedMatches.add(transformCanvasMatches(cm, allResults, tileIdToSpec)); + } + + // save to target match collection + targetMatchClient.saveMatches(transformedMatches); + + LOG.info("transformMatchesForGroup: exit, groupId={}, transformed {} matches", + groupId, transformedMatches.size()); + } + + /** + * Transforms a single CanvasMatches by applying creep corrections to both p and q coordinates. + */ + static CanvasMatches transformCanvasMatches(final CanvasMatches cm, + final Map allResults, + final Map tileIdToSpec) { + final Matches matches = cm.getMatches(); + final double[][] ps = matches.getPs(); + final double[][] qs = matches.getQs(); + final double[] ws = matches.getWs(); + final int n = ws.length; + + // deep copy coordinate arrays + final double[][] newPs = new double[][] { Arrays.copyOf(ps[0], n), Arrays.copyOf(ps[1], n) }; + final double[][] newQs = new double[][] { Arrays.copyOf(qs[0], n), Arrays.copyOf(qs[1], n) }; + + // transform p and q coordinates using their respective tile's creep correction + transformMatchCoordinates(newPs, cm.getpId(), cm.getpGroupId(), allResults, tileIdToSpec); + transformMatchCoordinates(newQs, cm.getqId(), cm.getqGroupId(), allResults, tileIdToSpec); + + return new CanvasMatches(cm.getpGroupId(), cm.getpId(), + cm.getqGroupId(), cm.getqId(), + new Matches(newPs, newQs, Arrays.copyOf(ws, n))); + } + + /** + * Transforms match coordinates in place by applying creep correction. + * Uses the pattern from {@link MultiSemUtilities#transformMFOVMatchesForTile}: + * invert post-matching transforms, apply CC, re-apply post-matching transforms. + */ + static void transformMatchCoordinates(final double[][] coords, + final String tileId, + final String groupId, + final Map allResults, + final Map tileIdToSpec) { + + // look up CC spec for this tile's mFOV + final String mfov = MultiSemUtilities.getMagcMfovForTileId(tileId); + final ZLayerResult layerResult = allResults.get(groupId); + if (layerResult == null) { + return; + } + final TransformSpec ccSpec = layerResult.mfovCorrections.get(mfov); + if (ccSpec == null) { + return; + } + + // get tile's post-matching transforms (alignment transforms applied after feature matching) + final TileSpec tileSpec = tileIdToSpec.get(tileId); + if (tileSpec == null) { + LOG.warn("transformMatchCoordinates: no tile spec found for {}", tileId); + return; + } + + final List postMatchingTransformList = + tileSpec.getPostMatchingTransformList().getList(null); + + if (postMatchingTransformList.isEmpty()) { + // no post-matching transforms; CC applies directly to match coordinates + final CoordinateTransform ccTransform = ccSpec.getNewInstance(); + for (int i = 0; i < coords[0].length; i++) { + final double[] point = new double[] { coords[0][i], coords[1][i] }; + ccTransform.applyInPlace(point); + coords[0][i] = point[0]; + coords[1][i] = point[1]; + } + return; + } + + // build invertible post-matching transform list + final InvertibleCoordinateTransformList invertiblePostMatching = + new InvertibleCoordinateTransformList<>(); + for (final CoordinateTransform ct : postMatchingTransformList) { + invertiblePostMatching.add((InvertibleCoordinateTransform) ct); + } + + final CoordinateTransform ccTransform = ccSpec.getNewInstance(); + + // transform each point: new_world = postMatching(CC(postMatching^(-1)(old_world))) + for (int i = 0; i < coords[0].length; i++) { + final double[] point = new double[] { coords[0][i], coords[1][i] }; + try { + // step 1: invert post-matching transforms to get intermediate coordinates + invertiblePostMatching.applyInverseInPlace(point); + // step 2: apply creep correction + ccTransform.applyInPlace(point); + // step 3: re-apply post-matching transforms + invertiblePostMatching.applyInPlace(point); + + coords[0][i] = point[0]; + coords[1][i] = point[1]; + } catch (final NoninvertibleModelException e) { + LOG.warn("transformMatchCoordinates: skipping non-invertible point for tile {}, index {}", tileId, i); + } + } + } + + /** Result of processing a z-layer, containing tile count and per-mFOV correction specs. */ + public static class ZLayerResult implements Serializable { + public final int tileCount; + public final Map mfovCorrections; + + ZLayerResult(final int tileCount, final Map mfovCorrections) { + this.tileCount = tileCount; + this.mfovCorrections = mfovCorrections; + } + } + /** A pair of tile IDs representing geometrically adjacent sFOVs. */ static class TilePair { final String pTileId; diff --git a/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/multisem/CreepCorrectionSparkClient.java b/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/multisem/CreepCorrectionSparkClient.java index bad722172..1fee8ad0d 100644 --- a/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/multisem/CreepCorrectionSparkClient.java +++ b/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/multisem/CreepCorrectionSparkClient.java @@ -5,21 +5,23 @@ import java.io.IOException; import java.io.Serializable; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; +import org.apache.spark.broadcast.Broadcast; import org.janelia.alignment.spec.stack.StackMetaData; import org.janelia.render.client.ClientRunner; import org.janelia.render.client.RenderDataClient; import org.janelia.render.client.multisem.CreepCorrectionClient; +import org.janelia.render.client.multisem.CreepCorrectionClient.ZLayerResult; import org.janelia.render.client.parameter.CommandLineParameters; import org.janelia.render.client.parameter.RenderWebServiceParameters; import org.janelia.render.client.parameter.ZRangeParameters; import org.janelia.render.client.spark.LogUtilities; -import org.jetbrains.annotations.NotNull; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -32,6 +34,10 @@ * double-exponential correction is applied if validation passes. mFOVs that fail validation * are skipped (uploaded without correction).

* + *

After tile correction, existing point matches are transformed to account for the creep + * correction and saved to a new match collection (named {@code targetStack + "_match"}). + * This can be skipped with {@code --skipMatchCorrection}.

+ * * @see CreepCorrectionClient * * @author Michael Innerberger @@ -69,9 +75,18 @@ public static class Parameters extends CommandLineParameters { required = true) public String matchCollection; + @Parameter( + names = "--skipMatchCorrection", + description = "Skip transforming match coordinates (default is to transform them)") + public boolean skipMatchCorrection = false; + String getMatchOwner() { return matchOwner != null ? matchOwner : renderWeb.owner; } + + String getTargetMatchCollection() { + return targetStack + "_match"; + } } public static void main(final String[] args) { @@ -120,28 +135,69 @@ public void run() throws IOException { final StackMetaData sourceStackMetaData = sourceDataClient.getStackMetaData(parameters.stack); sourceDataClient.setupDerivedStack(sourceStackMetaData, parameters.targetStack); - LOG.info("run: distributing {} z values for processing", zValues.size()); + // Phase 1: process tiles and collect corrections + LOG.info("run: Phase 1 - distributing {} z values for tile correction", zValues.size()); final JavaRDD rddZValues = sparkContext.parallelize(zValues); - final JavaRDD rddTileCounts = rddZValues.map(this::processSingleLayer); - final List tileCountList = rddTileCounts.collect(); + final JavaRDD rddResults = rddZValues.map(this::processSingleLayer); + final List resultList = rddResults.collect(); - long total = 0; - for (final Integer tileCount : tileCountList) { - total += tileCount; + // collect all corrections on the driver + final Map allResults = new HashMap<>(); + long totalTiles = 0; + for (int i = 0; i < zValues.size(); i++) { + final ZLayerResult result = resultList.get(i); + totalTiles += result.tileCount; + allResults.put(String.valueOf(zValues.get(i).doubleValue()), result); } - LOG.info("run: processed {} tiles across {} z-layers", total, zValues.size()); + LOG.info("run: Phase 1 complete - processed {} tiles across {} z-layers", totalTiles, zValues.size()); // complete target stack on the driver sourceDataClient.setStackState(parameters.targetStack, StackMetaData.StackState.COMPLETE); + + // Phase 2: transform matches + if (!parameters.skipMatchCorrection) { + transformMatches(sparkContext, allResults); + } else { + LOG.info("run: skipping match correction (--skipMatchCorrection)"); + } } LOG.info("run: exit"); } - private Integer processSingleLayer(final Double z) throws IOException { + private void transformMatches(final JavaSparkContext sparkContext, + final Map allResults) + throws IOException { + + final RenderDataClient driverMatchClient = new RenderDataClient( + parameters.renderWeb.baseDataUrl, + parameters.getMatchOwner(), + parameters.matchCollection); + + final List pGroupIds = driverMatchClient.getMatchPGroupIds(); + + if (pGroupIds.isEmpty()) { + LOG.info("transformMatches: no match groups found, skipping"); + return; + } + + LOG.info("run: Phase 2 - distributing {} match groups for coordinate transformation", pGroupIds.size()); + + final Broadcast> broadcastResults = sparkContext.broadcast(allResults); + + final JavaRDD rddGroupIds = sparkContext.parallelize(pGroupIds); + + rddGroupIds.foreach(groupId -> { + transformMatchesForSingleGroup(groupId, broadcastResults.value()); + }); + + LOG.info("run: Phase 2 complete - transformed matches for {} groups", pGroupIds.size()); + } + + private ZLayerResult processSingleLayer(final Double z) throws IOException { LogUtilities.setupExecutorLog4j("z " + z); final RenderDataClient executorRenderClient = parameters.renderWeb.getDataClient(); @@ -158,5 +214,29 @@ private Integer processSingleLayer(final Double z) throws IOException { parameters.targetStack); } + private void transformMatchesForSingleGroup(final String groupId, + final Map allResults) + throws IOException { + LogUtilities.setupExecutorLog4j("matchTransform " + groupId); + + final RenderDataClient executorRenderClient = parameters.renderWeb.getDataClient(); + final RenderDataClient sourceMatchClient = new RenderDataClient( + parameters.renderWeb.baseDataUrl, + parameters.getMatchOwner(), + parameters.matchCollection); + final RenderDataClient targetMatchClient = new RenderDataClient( + parameters.renderWeb.baseDataUrl, + parameters.getMatchOwner(), + parameters.getTargetMatchCollection()); + + final CreepCorrectionClient correctionClient = new CreepCorrectionClient(); + correctionClient.transformMatchesForGroup(groupId, + allResults, + executorRenderClient, + sourceMatchClient, + targetMatchClient, + parameters.stack); + } + private static final Logger LOG = LoggerFactory.getLogger(CreepCorrectionSparkClient.class); } From 1376fb55abbcacca07eeba826c8ac04d057f0097 Mon Sep 17 00:00:00 2001 From: Michael Innerberger Date: Tue, 24 Mar 2026 17:02:20 -0400 Subject: [PATCH 07/13] Fix wrong transformation of matches --- .../multisem/CreepCorrectionClient.java | 90 +++---------------- .../multisem/CreepCorrectionSparkClient.java | 5 +- 2 files changed, 13 insertions(+), 82 deletions(-) diff --git a/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java b/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java index cc814ab94..486597bb3 100644 --- a/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java +++ b/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java @@ -7,7 +7,6 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; @@ -29,9 +28,6 @@ import mpicbg.models.AffineModel2D; import mpicbg.models.CoordinateTransform; -import mpicbg.models.InvertibleCoordinateTransform; -import mpicbg.models.InvertibleCoordinateTransformList; -import mpicbg.models.NoninvertibleModelException; import mpicbg.models.NotEnoughDataPointsException; import mpicbg.models.PointMatch; @@ -435,10 +431,8 @@ private static String pairKey(final String id1, final String id2) { */ public void transformMatchesForGroup(final String groupId, final Map allResults, - final RenderDataClient renderDataClient, final RenderDataClient sourceMatchClient, - final RenderDataClient targetMatchClient, - final String stack) + final RenderDataClient targetMatchClient) throws IOException { LOG.info("transformMatchesForGroup: entry, groupId={}", groupId); @@ -455,28 +449,10 @@ public void transformMatchesForGroup(final String groupId, return; } - // collect all z-layers referenced by these matches - final Set neededZLayers = new HashSet<>(); - neededZLayers.add(groupId); - for (final CanvasMatches cm : outsideMatches) { - neededZLayers.add(cm.getpGroupId()); - neededZLayers.add(cm.getqGroupId()); - } - - // load tile specs for all referenced z-layers from source stack - final Map tileIdToSpec = new HashMap<>(); - for (final String zString : neededZLayers) { - final ResolvedTileSpecCollection resolved = - renderDataClient.getResolvedTiles(stack, Double.parseDouble(zString)); - for (final TileSpec ts : resolved.getTileSpecs()) { - tileIdToSpec.put(ts.getTileId(), ts); - } - } - // transform and collect matches final List transformedMatches = new ArrayList<>(); for (final CanvasMatches cm : allMatches) { - transformedMatches.add(transformCanvasMatches(cm, allResults, tileIdToSpec)); + transformedMatches.add(transformCanvasMatches(cm, allResults)); } // save to target match collection @@ -490,8 +466,7 @@ public void transformMatchesForGroup(final String groupId, * Transforms a single CanvasMatches by applying creep corrections to both p and q coordinates. */ static CanvasMatches transformCanvasMatches(final CanvasMatches cm, - final Map allResults, - final Map tileIdToSpec) { + final Map allResults) { final Matches matches = cm.getMatches(); final double[][] ps = matches.getPs(); final double[][] qs = matches.getQs(); @@ -503,8 +478,8 @@ static CanvasMatches transformCanvasMatches(final CanvasMatches cm, final double[][] newQs = new double[][] { Arrays.copyOf(qs[0], n), Arrays.copyOf(qs[1], n) }; // transform p and q coordinates using their respective tile's creep correction - transformMatchCoordinates(newPs, cm.getpId(), cm.getpGroupId(), allResults, tileIdToSpec); - transformMatchCoordinates(newQs, cm.getqId(), cm.getqGroupId(), allResults, tileIdToSpec); + transformMatchCoordinates(newPs, cm.getpId(), cm.getpGroupId(), allResults); + transformMatchCoordinates(newQs, cm.getqId(), cm.getqGroupId(), allResults); return new CanvasMatches(cm.getpGroupId(), cm.getpId(), cm.getqGroupId(), cm.getqId(), @@ -513,14 +488,14 @@ static CanvasMatches transformCanvasMatches(final CanvasMatches cm, /** * Transforms match coordinates in place by applying creep correction. - * Uses the pattern from {@link MultiSemUtilities#transformMFOVMatchesForTile}: - * invert post-matching transforms, apply CC, re-apply post-matching transforms. + * Match coordinates are in post-lens/pre-alignment space (montage matches are derived before + * alignment), which is the same space where the CC transform operates (it is inserted + * before the alignment transform). So we apply CC directly without touching alignment. */ static void transformMatchCoordinates(final double[][] coords, final String tileId, final String groupId, - final Map allResults, - final Map tileIdToSpec) { + final Map allResults) { // look up CC spec for this tile's mFOV final String mfov = MultiSemUtilities.getMagcMfovForTileId(tileId); @@ -533,53 +508,12 @@ static void transformMatchCoordinates(final double[][] coords, return; } - // get tile's post-matching transforms (alignment transforms applied after feature matching) - final TileSpec tileSpec = tileIdToSpec.get(tileId); - if (tileSpec == null) { - LOG.warn("transformMatchCoordinates: no tile spec found for {}", tileId); - return; - } - - final List postMatchingTransformList = - tileSpec.getPostMatchingTransformList().getList(null); - - if (postMatchingTransformList.isEmpty()) { - // no post-matching transforms; CC applies directly to match coordinates - final CoordinateTransform ccTransform = ccSpec.getNewInstance(); - for (int i = 0; i < coords[0].length; i++) { - final double[] point = new double[] { coords[0][i], coords[1][i] }; - ccTransform.applyInPlace(point); - coords[0][i] = point[0]; - coords[1][i] = point[1]; - } - return; - } - - // build invertible post-matching transform list - final InvertibleCoordinateTransformList invertiblePostMatching = - new InvertibleCoordinateTransformList<>(); - for (final CoordinateTransform ct : postMatchingTransformList) { - invertiblePostMatching.add((InvertibleCoordinateTransform) ct); - } - final CoordinateTransform ccTransform = ccSpec.getNewInstance(); - - // transform each point: new_world = postMatching(CC(postMatching^(-1)(old_world))) for (int i = 0; i < coords[0].length; i++) { final double[] point = new double[] { coords[0][i], coords[1][i] }; - try { - // step 1: invert post-matching transforms to get intermediate coordinates - invertiblePostMatching.applyInverseInPlace(point); - // step 2: apply creep correction - ccTransform.applyInPlace(point); - // step 3: re-apply post-matching transforms - invertiblePostMatching.applyInPlace(point); - - coords[0][i] = point[0]; - coords[1][i] = point[1]; - } catch (final NoninvertibleModelException e) { - LOG.warn("transformMatchCoordinates: skipping non-invertible point for tile {}, index {}", tileId, i); - } + ccTransform.applyInPlace(point); + coords[0][i] = point[0]; + coords[1][i] = point[1]; } } diff --git a/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/multisem/CreepCorrectionSparkClient.java b/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/multisem/CreepCorrectionSparkClient.java index 1fee8ad0d..5a95a644a 100644 --- a/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/multisem/CreepCorrectionSparkClient.java +++ b/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/multisem/CreepCorrectionSparkClient.java @@ -219,7 +219,6 @@ private void transformMatchesForSingleGroup(final String groupId, throws IOException { LogUtilities.setupExecutorLog4j("matchTransform " + groupId); - final RenderDataClient executorRenderClient = parameters.renderWeb.getDataClient(); final RenderDataClient sourceMatchClient = new RenderDataClient( parameters.renderWeb.baseDataUrl, parameters.getMatchOwner(), @@ -232,10 +231,8 @@ private void transformMatchesForSingleGroup(final String groupId, final CreepCorrectionClient correctionClient = new CreepCorrectionClient(); correctionClient.transformMatchesForGroup(groupId, allResults, - executorRenderClient, sourceMatchClient, - targetMatchClient, - parameters.stack); + targetMatchClient); } private static final Logger LOG = LoggerFactory.getLogger(CreepCorrectionSparkClient.class); From b4a52509c9f24f8c49ae55f5c34786b75900292a Mon Sep 17 00:00:00 2001 From: Michael Innerberger Date: Sun, 29 Mar 2026 16:22:10 -0400 Subject: [PATCH 08/13] Add csv output option for creep correction parameters --- .../multisem/CreepCorrectionClient.java | 196 ++++++++++-------- .../multisem/CreepCorrectionSparkClient.java | 62 ++++-- 2 files changed, 155 insertions(+), 103 deletions(-) diff --git a/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java b/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java index 486597bb3..823d09a37 100644 --- a/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java +++ b/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java @@ -22,6 +22,7 @@ import org.janelia.alignment.spec.ResolvedTileSpecCollection; import org.janelia.alignment.spec.TileSpec; import org.janelia.alignment.spec.TransformSpec; +import org.janelia.alignment.transform.StageCreepCorrectionTransform; import org.janelia.render.client.RenderDataClient; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -75,13 +76,13 @@ public class CreepCorrectionClient { * Processes all mFOVs for a given z-layer: loads tiles and matches, estimates creep correction, * applies it where valid, and saves the results to the target stack. * - * @return result containing tile count and per-mFOV correction specs + * @return per-mFOV results (one per mFOV, including skipped ones) */ - public ZLayerResult processZLayer(final double z, - final RenderDataClient renderDataClient, - final RenderDataClient matchDataClient, - final String stack, - final String targetStack) + public List processZLayer(final double z, + final RenderDataClient renderDataClient, + final RenderDataClient matchDataClient, + final String stack, + final String targetStack) throws IOException { LOG.info("processZLayer: entry, z={}", z); @@ -108,16 +109,18 @@ public ZLayerResult processZLayer(final double z, } // process each mFOV independently - final Map mfovCorrections = new HashMap<>(); + final List mfovResults = new ArrayList<>(); int correctedMFOVCount = 0; int skippedMFOVCount = 0; for (final Map.Entry> entry : mfovToTiles.entrySet()) { final String mfov = entry.getKey(); final List mfovTiles = entry.getValue(); - final TransformSpec correctionSpec = processMFOV(mfov, mfovTiles, pairKeyToMatches, resolvedTiles); - if (correctionSpec != null) { - mfovCorrections.put(mfov, correctionSpec); + final MfovResult result = processMFOV(mfov, mfovTiles, pairKeyToMatches, resolvedTiles); + mfovResults.add(result); + LOG.info("processMFOV: {}", result.toCsvRow()); + + if (result.isValid) { correctedMFOVCount++; } else { skippedMFOVCount++; @@ -130,25 +133,22 @@ public ZLayerResult processZLayer(final double z, LOG.info("processZLayer: exit, z={}, correctedMFOVs={}, skippedMFOVs={}, totalTiles={}", z, correctedMFOVCount, skippedMFOVCount, resolvedTiles.getTileCount()); - return new ZLayerResult(resolvedTiles.getTileCount(), mfovCorrections); + return mfovResults; } /** * Processes a single mFOV: finds neighbor pairs, estimates stretch, validates, and applies correction. - * - * @return the correction transform spec if applied, or null if skipped */ - TransformSpec processMFOV(final String mfov, - final List mfovTiles, - final Map pairKeyToMatches, - final ResolvedTileSpecCollection resolvedTiles) { + MfovResult processMFOV(final String mfov, + final List mfovTiles, + final Map pairKeyToMatches, + final ResolvedTileSpecCollection resolvedTiles) { // find geometric neighbor pairs final List neighborPairs = findGeometricNeighborPairs(mfovTiles); if (neighborPairs.isEmpty()) { - LOG.warn("processMFOV: no geometric neighbor pairs found for mFOV {}, skipping", mfov); - return null; + return MfovResult.invalid(mfov, 0, "no geometric neighbor pairs found"); } // estimate y-stretch for each pair that has matches @@ -173,25 +173,17 @@ TransformSpec processMFOV(final String mfov, pairsWithoutMatches, neighborPairs.size(), mfov); } - // validate results - final ValidationResult validation = validateResults(stretches, mfov); + // validate and build result + final MfovResult result = validateStretches(stretches, mfov); - if (!validation.isValid) { - LOG.warn("processMFOV: skipping mFOV {} - {}", mfov, validation.diagnosticMessage); - return null; + if (result.isValid) { + final Set mfovTileIds = mfovTiles.stream() + .map(TileSpec::getTileId) + .collect(Collectors.toSet()); + applyCorrectionToMFOV(resolvedTiles, mfovTileIds, result.correctionSpec); } - // build and apply correction transform - final TransformSpec correctionSpec = buildCorrectionTransformSpec(validation.amplitude); - final Set mfovTileIds = mfovTiles.stream() - .map(TileSpec::getTileId) - .collect(Collectors.toSet()); - applyCorrectionToMFOV(resolvedTiles, mfovTileIds, correctionSpec); - - LOG.info("processMFOV: applied creep correction to mFOV {} with amplitude={}, medianStretch={}, stddev={}", - mfov, validation.amplitude, validation.medianStretch, validation.stddev); - - return correctionSpec; + return result; } /** @@ -317,14 +309,13 @@ static double computeCorrectionAmplitude(final double medianStretch) { } /** - * Centralized validation of all stretch estimates for a single mFOV. - * On failure, returns an invalid result; the mFOV should be skipped (uploaded without correction). + * Validates stretch estimates for a single mFOV and returns the result. */ - ValidationResult validateResults(final List stretches, final String mfov) { + MfovResult validateStretches(final List stretches, final String mfov) { final int totalPairs = stretches.size(); if (totalPairs == 0) { - return ValidationResult.invalid("no stretch estimates available"); + return MfovResult.invalid(mfov, 0, "no stretch estimates available"); } // filter out NaN and out-of-range values @@ -342,18 +333,18 @@ ValidationResult validateResults(final List stretches, final String mfov } if (nanCount > totalPairs / 2) { - LOG.warn("validateResults: mFOV {} has {} NaN stretches out of {} total", + LOG.warn("validateStretches: mFOV {} has {} NaN stretches out of {} total", mfov, nanCount, totalPairs); } if (outOfRangeCount > 0) { - LOG.info("validateResults: mFOV {} had {} stretches outside [{}, {}]", + LOG.info("validateStretches: mFOV {} had {} stretches outside [{}, {}]", mfov, outOfRangeCount, MIN_STRETCH, MAX_STRETCH); } if (validStretches.size() < MIN_VALID_STRETCHES) { - return ValidationResult.invalid( + return MfovResult.invalid(mfov, totalPairs, "only " + validStretches.size() + " valid stretches (need " + MIN_VALID_STRETCHES + - "), nanCount=" + nanCount + ", outOfRange=" + outOfRangeCount); + ") nanCount=" + nanCount + " outOfRange=" + outOfRangeCount); } // compute median @@ -374,24 +365,21 @@ ValidationResult validateResults(final List stretches, final String mfov final double stddev = Math.sqrt(variance); if (stddev > MAX_STRETCH_STDDEV) { - return ValidationResult.invalid( + return MfovResult.invalid(mfov, totalPairs, "stretch stddev " + stddev + " exceeds threshold " + MAX_STRETCH_STDDEV + - " (median=" + medianStretch + ", validCount=" + validStretches.size() + ")"); + " (median=" + medianStretch + " validCount=" + validStretches.size() + ")"); } // compute amplitude final double amplitude = computeCorrectionAmplitude(medianStretch); if (!Double.isFinite(amplitude)) { - return ValidationResult.invalid( + return MfovResult.invalid(mfov, totalPairs, "computed amplitude is not finite (median=" + medianStretch + ")"); } - LOG.info("validateResults: mFOV {} - totalPairs={}, valid={}, nan={}, outOfRange={}, " + - "median={}, stddev={}, amplitude={}", - mfov, totalPairs, validStretches.size(), nanCount, outOfRangeCount, - medianStretch, stddev, amplitude); - - return new ValidationResult(true, medianStretch, stddev, amplitude, "OK"); + return new MfovResult(mfov, medianStretch, stddev, amplitude, + validStretches.size(), totalPairs, + buildCorrectionTransformSpec(amplitude)); } /** @@ -430,7 +418,7 @@ private static String pairKey(final String id1, final String id2) { * Handles both within-group and outside-group matches. */ public void transformMatchesForGroup(final String groupId, - final Map allResults, + final Map> allResults, final RenderDataClient sourceMatchClient, final RenderDataClient targetMatchClient) throws IOException { @@ -466,7 +454,7 @@ public void transformMatchesForGroup(final String groupId, * Transforms a single CanvasMatches by applying creep corrections to both p and q coordinates. */ static CanvasMatches transformCanvasMatches(final CanvasMatches cm, - final Map allResults) { + final Map> allResults) { final Matches matches = cm.getMatches(); final double[][] ps = matches.getPs(); final double[][] qs = matches.getQs(); @@ -495,15 +483,22 @@ static CanvasMatches transformCanvasMatches(final CanvasMatches cm, static void transformMatchCoordinates(final double[][] coords, final String tileId, final String groupId, - final Map allResults) { + final Map> allResults) { // look up CC spec for this tile's mFOV final String mfov = MultiSemUtilities.getMagcMfovForTileId(tileId); - final ZLayerResult layerResult = allResults.get(groupId); - if (layerResult == null) { + final List layerResults = allResults.get(groupId); + if (layerResults == null) { return; } - final TransformSpec ccSpec = layerResult.mfovCorrections.get(mfov); + + TransformSpec ccSpec = null; + for (final MfovResult r : layerResults) { + if (r.mfov.equals(mfov) && r.correctionSpec != null) { + ccSpec = r.correctionSpec; + break; + } + } if (ccSpec == null) { return; } @@ -517,14 +512,60 @@ static void transformMatchCoordinates(final double[][] coords, } } - /** Result of processing a z-layer, containing tile count and per-mFOV correction specs. */ - public static class ZLayerResult implements Serializable { - public final int tileCount; - public final Map mfovCorrections; + /** Result of processing a single mFOV, including validation outcome and correction spec. */ + public static class MfovResult implements Serializable { + + public static final String CSV_HEADER = "mfov,medianStretch,stddev,amplitude,validPairs,totalPairs,isValid,diagnosticMessage"; + + public final String mfov; + public final double medianStretch; + public final double stddev; + public final double amplitude; + public final int validPairs; + public final int totalPairs; + public final boolean isValid; + public final String diagnosticMessage; + public final TransformSpec correctionSpec; + + MfovResult(final String mfov, + final double medianStretch, + final double stddev, + final double amplitude, + final int validPairs, + final int totalPairs, + final TransformSpec correctionSpec) { + this.mfov = mfov; + this.medianStretch = medianStretch; + this.stddev = stddev; + this.amplitude = amplitude; + this.validPairs = validPairs; + this.totalPairs = totalPairs; + this.isValid = true; + this.diagnosticMessage = "OK"; + this.correctionSpec = correctionSpec; + } + + private MfovResult(final String mfov, + final int totalPairs, + final String diagnosticMessage) { + this.mfov = mfov; + this.medianStretch = Double.NaN; + this.stddev = Double.NaN; + this.amplitude = Double.NaN; + this.validPairs = 0; + this.totalPairs = totalPairs; + this.isValid = false; + this.diagnosticMessage = diagnosticMessage; + this.correctionSpec = null; + } + + static MfovResult invalid(final String mfov, final int totalPairs, final String reason) { + return new MfovResult(mfov, totalPairs, reason); + } - ZLayerResult(final int tileCount, final Map mfovCorrections) { - this.tileCount = tileCount; - this.mfovCorrections = mfovCorrections; + public String toCsvRow() { + return mfov + "," + medianStretch + "," + stddev + "," + amplitude + "," + + validPairs + "," + totalPairs + "," + isValid + "," + diagnosticMessage; } } @@ -539,31 +580,6 @@ static class TilePair { } } - /** Result of validating stretch estimates for an mFOV. */ - static class ValidationResult { - final boolean isValid; - final double medianStretch; - final double stddev; - final double amplitude; - final String diagnosticMessage; - - ValidationResult(final boolean isValid, - final double medianStretch, - final double stddev, - final double amplitude, - final String diagnosticMessage) { - this.isValid = isValid; - this.medianStretch = medianStretch; - this.stddev = stddev; - this.amplitude = amplitude; - this.diagnosticMessage = diagnosticMessage; - } - - static ValidationResult invalid(final String diagnosticMessage) { - return new ValidationResult(false, Double.NaN, Double.NaN, Double.NaN, diagnosticMessage); - } - } - static class StageCoordinates { final double x; final double y; diff --git a/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/multisem/CreepCorrectionSparkClient.java b/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/multisem/CreepCorrectionSparkClient.java index 5a95a644a..24c126492 100644 --- a/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/multisem/CreepCorrectionSparkClient.java +++ b/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/multisem/CreepCorrectionSparkClient.java @@ -4,7 +4,9 @@ import com.beust.jcommander.ParametersDelegate; import java.io.IOException; +import java.io.PrintWriter; import java.io.Serializable; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -17,7 +19,7 @@ import org.janelia.render.client.ClientRunner; import org.janelia.render.client.RenderDataClient; import org.janelia.render.client.multisem.CreepCorrectionClient; -import org.janelia.render.client.multisem.CreepCorrectionClient.ZLayerResult; +import org.janelia.render.client.multisem.CreepCorrectionClient.MfovResult; import org.janelia.render.client.parameter.CommandLineParameters; import org.janelia.render.client.parameter.RenderWebServiceParameters; import org.janelia.render.client.parameter.ZRangeParameters; @@ -80,6 +82,11 @@ public static class Parameters extends CommandLineParameters { description = "Skip transforming match coordinates (default is to transform them)") public boolean skipMatchCorrection = false; + @Parameter( + names = "--parameterCsv", + description = "Path to write per-mFOV parameter CSV with stretch estimates and validation results") + public String parameterCsv; + String getMatchOwner() { return matchOwner != null ? matchOwner : renderWeb.owner; } @@ -140,19 +147,25 @@ public void run() throws IOException { final JavaRDD rddZValues = sparkContext.parallelize(zValues); - final JavaRDD rddResults = rddZValues.map(this::processSingleLayer); - final List resultList = rddResults.collect(); + final JavaRDD> rddResults = rddZValues.map(this::processSingleLayer); + final List> resultList = rddResults.collect(); // collect all corrections on the driver - final Map allResults = new HashMap<>(); - long totalTiles = 0; + final Map> allResults = new HashMap<>(); for (int i = 0; i < zValues.size(); i++) { - final ZLayerResult result = resultList.get(i); - totalTiles += result.tileCount; - allResults.put(String.valueOf(zValues.get(i).doubleValue()), result); + allResults.put(String.valueOf(zValues.get(i).doubleValue()), resultList.get(i)); } - LOG.info("run: Phase 1 complete - processed {} tiles across {} z-layers", totalTiles, zValues.size()); + LOG.info("run: Phase 1 complete - processed {} z-layers", zValues.size()); + + // write parameter CSV if requested (non-fatal if it fails) + if (parameters.parameterCsv != null) { + try { + writeParameterCsv(parameters.parameterCsv, allResults); + } catch (final Exception e) { + LOG.error("run: failed to write parameter CSV to " + parameters.parameterCsv, e); + } + } // complete target stack on the driver sourceDataClient.setStackState(parameters.targetStack, StackMetaData.StackState.COMPLETE); @@ -169,7 +182,7 @@ public void run() throws IOException { } private void transformMatches(final JavaSparkContext sparkContext, - final Map allResults) + final Map> allResults) throws IOException { final RenderDataClient driverMatchClient = new RenderDataClient( @@ -186,7 +199,7 @@ private void transformMatches(final JavaSparkContext sparkContext, LOG.info("run: Phase 2 - distributing {} match groups for coordinate transformation", pGroupIds.size()); - final Broadcast> broadcastResults = sparkContext.broadcast(allResults); + final Broadcast>> broadcastResults = sparkContext.broadcast(allResults); final JavaRDD rddGroupIds = sparkContext.parallelize(pGroupIds); @@ -197,7 +210,7 @@ private void transformMatches(final JavaSparkContext sparkContext, LOG.info("run: Phase 2 complete - transformed matches for {} groups", pGroupIds.size()); } - private ZLayerResult processSingleLayer(final Double z) throws IOException { + private List processSingleLayer(final Double z) throws IOException { LogUtilities.setupExecutorLog4j("z " + z); final RenderDataClient executorRenderClient = parameters.renderWeb.getDataClient(); @@ -215,7 +228,7 @@ private ZLayerResult processSingleLayer(final Double z) throws IOException { } private void transformMatchesForSingleGroup(final String groupId, - final Map allResults) + final Map> allResults) throws IOException { LogUtilities.setupExecutorLog4j("matchTransform " + groupId); @@ -235,5 +248,28 @@ private void transformMatchesForSingleGroup(final String groupId, targetMatchClient); } + private void writeParameterCsv(final String csvPath, + final Map> allResults) + throws IOException { + + LOG.info("writeParameterCsv: writing to {}", csvPath); + + // sort by scan (z) for deterministic output + final List sortedScans = new ArrayList<>(allResults.keySet()); + sortedScans.sort((a, b) -> Double.compare(Double.parseDouble(a), Double.parseDouble(b))); + + try (final PrintWriter writer = new PrintWriter(csvPath)) { + writer.println("scan," + MfovResult.CSV_HEADER); + + for (final String scan : sortedScans) { + for (final MfovResult result : allResults.get(scan)) { + writer.println(scan + "," + result.toCsvRow()); + } + } + } + + LOG.info("writeParameterCsv: done"); + } + private static final Logger LOG = LoggerFactory.getLogger(CreepCorrectionSparkClient.class); } From 42aa7c2e3a3da56f4f82abce7c598c92a1693328 Mon Sep 17 00:00:00 2001 From: Michael Innerberger Date: Sun, 29 Mar 2026 16:48:57 -0400 Subject: [PATCH 09/13] Improve reporting of stretch values --- .../multisem/CreepCorrectionClient.java | 36 +++++++------------ 1 file changed, 12 insertions(+), 24 deletions(-) diff --git a/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java b/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java index 823d09a37..5c2c51020 100644 --- a/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java +++ b/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java @@ -365,21 +365,22 @@ MfovResult validateStretches(final List stretches, final String mfov) { final double stddev = Math.sqrt(variance); if (stddev > MAX_STRETCH_STDDEV) { - return MfovResult.invalid(mfov, totalPairs, - "stretch stddev " + stddev + " exceeds threshold " + MAX_STRETCH_STDDEV + - " (median=" + medianStretch + " validCount=" + validStretches.size() + ")"); + return new MfovResult(mfov, medianStretch, stddev, Double.NaN, + validStretches.size(), totalPairs, null, + "stretch stddev exceeds threshold " + MAX_STRETCH_STDDEV); } // compute amplitude final double amplitude = computeCorrectionAmplitude(medianStretch); if (!Double.isFinite(amplitude)) { - return MfovResult.invalid(mfov, totalPairs, - "computed amplitude is not finite (median=" + medianStretch + ")"); + return new MfovResult(mfov, medianStretch, stddev, amplitude, + validStretches.size(), totalPairs, null, + "computed amplitude is not finite"); } return new MfovResult(mfov, medianStretch, stddev, amplitude, validStretches.size(), totalPairs, - buildCorrectionTransformSpec(amplitude)); + buildCorrectionTransformSpec(amplitude), "OK"); } /** @@ -533,34 +534,21 @@ public static class MfovResult implements Serializable { final double amplitude, final int validPairs, final int totalPairs, - final TransformSpec correctionSpec) { + final TransformSpec correctionSpec, + final String diagnosticMessage) { this.mfov = mfov; this.medianStretch = medianStretch; this.stddev = stddev; this.amplitude = amplitude; this.validPairs = validPairs; this.totalPairs = totalPairs; - this.isValid = true; - this.diagnosticMessage = "OK"; - this.correctionSpec = correctionSpec; - } - - private MfovResult(final String mfov, - final int totalPairs, - final String diagnosticMessage) { - this.mfov = mfov; - this.medianStretch = Double.NaN; - this.stddev = Double.NaN; - this.amplitude = Double.NaN; - this.validPairs = 0; - this.totalPairs = totalPairs; - this.isValid = false; + this.isValid = correctionSpec != null; this.diagnosticMessage = diagnosticMessage; - this.correctionSpec = null; + this.correctionSpec = correctionSpec; } static MfovResult invalid(final String mfov, final int totalPairs, final String reason) { - return new MfovResult(mfov, totalPairs, reason); + return new MfovResult(mfov, Double.NaN, Double.NaN, Double.NaN, 0, totalPairs, null, reason); } public String toCsvRow() { From dffcffb245da6b15178d68fc43e73a8f7bbe6e0e Mon Sep 17 00:00:00 2001 From: Michael Innerberger Date: Sun, 29 Mar 2026 16:54:11 -0400 Subject: [PATCH 10/13] Add minimum correction threshold --- .../render/client/multisem/CreepCorrectionClient.java | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java b/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java index 5c2c51020..ea5977ff5 100644 --- a/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java +++ b/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java @@ -61,6 +61,7 @@ public class CreepCorrectionClient { private static final double MAX_STRETCH = 1.1; private static final int MIN_VALID_STRETCHES = 10; private static final double MAX_STRETCH_STDDEV = 0.02; + private static final double MIN_MEDIAN_STRETCH = 0.995; // RANSAC parameters for pairwise affine estimation private static final int RANSAC_ITERATIONS = 1000; @@ -370,6 +371,12 @@ MfovResult validateStretches(final List stretches, final String mfov) { "stretch stddev exceeds threshold " + MAX_STRETCH_STDDEV); } + if (medianStretch > MIN_MEDIAN_STRETCH) { + return new MfovResult(mfov, medianStretch, stddev, Double.NaN, + validStretches.size(), totalPairs, null, + "median stretch " + medianStretch + " above threshold " + MIN_MEDIAN_STRETCH); + } + // compute amplitude final double amplitude = computeCorrectionAmplitude(medianStretch); if (!Double.isFinite(amplitude)) { From 8adf2a967af0573ba4bb55697d4bd757f67d4102 Mon Sep 17 00:00:00 2001 From: Michael Innerberger Date: Fri, 10 Apr 2026 12:02:32 -0400 Subject: [PATCH 11/13] Improve logging (always log correction parameters; csv or log) --- .../multisem/CreepCorrectionClient.java | 13 ++--- .../multisem/CreepCorrectionSparkClient.java | 50 +++++++++---------- 2 files changed, 28 insertions(+), 35 deletions(-) diff --git a/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java b/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java index ea5977ff5..7c63ee335 100644 --- a/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java +++ b/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java @@ -117,14 +117,16 @@ public List processZLayer(final double z, final String mfov = entry.getKey(); final List mfovTiles = entry.getValue(); + LOG.info("processZLayer: processing mFOV {}", mfov); final MfovResult result = processMFOV(mfov, mfovTiles, pairKeyToMatches, resolvedTiles); mfovResults.add(result); - LOG.info("processMFOV: {}", result.toCsvRow()); if (result.isValid) { correctedMFOVCount++; + LOG.info("processZLayer: successfully corrected mFOV {}", mfov); } else { skippedMFOVCount++; + LOG.info("processZLayer: failed correcting mFOV {}: {}", mfov, result.diagnosticMessage); } } @@ -333,15 +335,6 @@ MfovResult validateStretches(final List stretches, final String mfov) { } } - if (nanCount > totalPairs / 2) { - LOG.warn("validateStretches: mFOV {} has {} NaN stretches out of {} total", - mfov, nanCount, totalPairs); - } - if (outOfRangeCount > 0) { - LOG.info("validateStretches: mFOV {} had {} stretches outside [{}, {}]", - mfov, outOfRangeCount, MIN_STRETCH, MAX_STRETCH); - } - if (validStretches.size() < MIN_VALID_STRETCHES) { return MfovResult.invalid(mfov, totalPairs, "only " + validStretches.size() + " valid stretches (need " + MIN_VALID_STRETCHES + diff --git a/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/multisem/CreepCorrectionSparkClient.java b/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/multisem/CreepCorrectionSparkClient.java index 24c126492..936a8ad0f 100644 --- a/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/multisem/CreepCorrectionSparkClient.java +++ b/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/multisem/CreepCorrectionSparkClient.java @@ -7,6 +7,7 @@ import java.io.PrintWriter; import java.io.Serializable; import java.util.ArrayList; +import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -158,14 +159,8 @@ public void run() throws IOException { LOG.info("run: Phase 1 complete - processed {} z-layers", zValues.size()); - // write parameter CSV if requested (non-fatal if it fails) - if (parameters.parameterCsv != null) { - try { - writeParameterCsv(parameters.parameterCsv, allResults); - } catch (final Exception e) { - LOG.error("run: failed to write parameter CSV to " + parameters.parameterCsv, e); - } - } + // write or log results on the driver + reportResults(allResults); // complete target stack on the driver sourceDataClient.setStackState(parameters.targetStack, StackMetaData.StackState.COMPLETE); @@ -203,9 +198,7 @@ private void transformMatches(final JavaSparkContext sparkContext, final JavaRDD rddGroupIds = sparkContext.parallelize(pGroupIds); - rddGroupIds.foreach(groupId -> { - transformMatchesForSingleGroup(groupId, broadcastResults.value()); - }); + rddGroupIds.foreach(groupId -> transformMatchesForSingleGroup(groupId, broadcastResults.value())); LOG.info("run: Phase 2 complete - transformed matches for {} groups", pGroupIds.size()); } @@ -248,27 +241,34 @@ private void transformMatchesForSingleGroup(final String groupId, targetMatchClient); } - private void writeParameterCsv(final String csvPath, - final Map> allResults) - throws IOException { - - LOG.info("writeParameterCsv: writing to {}", csvPath); - - // sort by scan (z) for deterministic output + private void reportResults(final Map> allResults) { final List sortedScans = new ArrayList<>(allResults.keySet()); - sortedScans.sort((a, b) -> Double.compare(Double.parseDouble(a), Double.parseDouble(b))); - - try (final PrintWriter writer = new PrintWriter(csvPath)) { - writer.println("scan," + MfovResult.CSV_HEADER); + sortedScans.sort(Comparator.comparingDouble(Double::parseDouble)); + + boolean written = false; + if (parameters.parameterCsv != null) { + try (final PrintWriter writer = new PrintWriter(parameters.parameterCsv)) { + writer.println("scan," + MfovResult.CSV_HEADER); + for (final String scan : sortedScans) { + for (final MfovResult result : allResults.get(scan)) { + writer.println(scan + "," + result.toCsvRow()); + } + } + written = true; + LOG.info("reportResults: wrote creep correction parameters to {}", parameters.parameterCsv); + } catch (final Exception e) { + LOG.error("reportResults: failed to write CSV to {}, logging instead", parameters.parameterCsv, e); + } + } + if (!written) { + LOG.info("reportResults: scan,{}", MfovResult.CSV_HEADER); for (final String scan : sortedScans) { for (final MfovResult result : allResults.get(scan)) { - writer.println(scan + "," + result.toCsvRow()); + LOG.info("reportResults: {},{}", scan, result.toCsvRow()); } } } - - LOG.info("writeParameterCsv: done"); } private static final Logger LOG = LoggerFactory.getLogger(CreepCorrectionSparkClient.class); From e9976145516823e6319d4cbf6d2f3e5ee06dd9bc Mon Sep 17 00:00:00 2001 From: Eric Trautman Date: Mon, 13 Apr 2026 09:39:22 -0400 Subject: [PATCH 12/13] first draft adapting creep correction code to fit within a multi-SEM alignment pipeline --- .../parameter/CreepCorrectionParameters.java | 87 +++++ .../multisem/CreepCorrectionSparkClient.java | 332 +++++++++++------- .../pipeline/AlignmentPipelineParameters.java | 9 + .../pipeline/AlignmentPipelineStepId.java | 2 + 4 files changed, 309 insertions(+), 121 deletions(-) create mode 100644 render-ws-java-client/src/main/java/org/janelia/render/client/parameter/CreepCorrectionParameters.java diff --git a/render-ws-java-client/src/main/java/org/janelia/render/client/parameter/CreepCorrectionParameters.java b/render-ws-java-client/src/main/java/org/janelia/render/client/parameter/CreepCorrectionParameters.java new file mode 100644 index 000000000..10389cdf7 --- /dev/null +++ b/render-ws-java-client/src/main/java/org/janelia/render/client/parameter/CreepCorrectionParameters.java @@ -0,0 +1,87 @@ +package org.janelia.render.client.parameter; + +import com.beust.jcommander.Parameter; +import com.beust.jcommander.Parameters; + +import java.io.File; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; + +/** + * Parameters for correcting multi-SEM image creep. + */ +@Parameters +public class CreepCorrectionParameters + implements Serializable { + + @Parameter( + names = "--creepTargetStackSuffix", + description = "Target stack name is the the source stack name with this suffix appended") + public String targetStackSuffix = "_crc"; + + @Parameter( + names = "--creepMinZ", + description = "Minimum Z value for layers that need creep correction") + public Double minZ; + + @Parameter( + names = "--creepMaxZ", + description = "Maximum Z value for layers that need creep correction") + public Double maxZ; + + @Parameter( + names = "--skipMatchCorrection", + description = "Skip transforming match coordinates (default is to transform them)") + public boolean skipMatchCorrection = false; + + @Parameter( + names = "--parameterCsvDir", + description = "Directory where a creep-correction-param..csv file " + + "with per-mFOV parameters, stretch estimates, and validation results " + + "will be written for each source stack") + public String parameterCsvDir; + + public CreepCorrectionParameters() { + } + + public void validate() + throws IllegalArgumentException { + + if ((targetStackSuffix == null) || (targetStackSuffix.trim().isEmpty())) { + throw new IllegalArgumentException("--creepTargetStackSuffix must be defined"); + } + + if (parameterCsvDir != null) { + final File csvDir = new File(parameterCsvDir); + if (! csvDir.isDirectory()) { + throw new IllegalArgumentException("--parameterCsvDir " + parameterCsvDir + " is not a valid directory"); + } + } + } + + public String getTargetStack(final String sourceStack) { + return sourceStack + targetStackSuffix; + } + + + public List buildCorrectedZValues(final List allZValues) { + final List correctedZValues = new ArrayList<>(allZValues); + for (final Double z : allZValues) { + if ( (minZ == null || (z >= minZ)) && (maxZ == null || (z <= maxZ)) ) { + correctedZValues.add(z); + } + } + return correctedZValues; + } + + public List buildUncorrectedZValues(final List allZValues) { + final List uncorrectedZValues = new ArrayList<>(allZValues); + for (final Double z : allZValues) { + if ( (minZ != null && (z < minZ)) || (maxZ != null && (z > maxZ)) ) { + uncorrectedZValues.add(z); + } + } + return uncorrectedZValues; + } +} \ No newline at end of file diff --git a/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/multisem/CreepCorrectionSparkClient.java b/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/multisem/CreepCorrectionSparkClient.java index 936a8ad0f..462751b89 100644 --- a/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/multisem/CreepCorrectionSparkClient.java +++ b/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/multisem/CreepCorrectionSparkClient.java @@ -1,8 +1,8 @@ package org.janelia.render.client.spark.multisem; -import com.beust.jcommander.Parameter; import com.beust.jcommander.ParametersDelegate; +import java.io.File; import java.io.IOException; import java.io.PrintWriter; import java.io.Serializable; @@ -16,15 +16,22 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.broadcast.Broadcast; +import org.janelia.alignment.match.MatchCollectionId; +import org.janelia.alignment.spec.ResolvedTileSpecCollection; +import org.janelia.alignment.spec.stack.StackId; import org.janelia.alignment.spec.stack.StackMetaData; +import org.janelia.alignment.spec.stack.StackWithZValues; import org.janelia.render.client.ClientRunner; import org.janelia.render.client.RenderDataClient; import org.janelia.render.client.multisem.CreepCorrectionClient; import org.janelia.render.client.multisem.CreepCorrectionClient.MfovResult; import org.janelia.render.client.parameter.CommandLineParameters; -import org.janelia.render.client.parameter.RenderWebServiceParameters; -import org.janelia.render.client.parameter.ZRangeParameters; +import org.janelia.render.client.parameter.CreepCorrectionParameters; +import org.janelia.render.client.parameter.MultiProjectParameters; import org.janelia.render.client.spark.LogUtilities; +import org.janelia.render.client.spark.pipeline.AlignmentPipelineParameters; +import org.janelia.render.client.spark.pipeline.AlignmentPipelineStep; +import org.janelia.render.client.spark.pipeline.AlignmentPipelineStepId; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -45,56 +52,16 @@ * * @author Michael Innerberger */ -public class CreepCorrectionSparkClient implements Serializable { +public class CreepCorrectionSparkClient + implements Serializable, AlignmentPipelineStep { public static class Parameters extends CommandLineParameters { @ParametersDelegate - public RenderWebServiceParameters renderWeb = new RenderWebServiceParameters(); + public MultiProjectParameters multiProject; @ParametersDelegate - public ZRangeParameters layerRange = new ZRangeParameters(); - - @Parameter( - names = "--stack", - description = "Name of source stack", - required = true) - public String stack; - - @Parameter( - names = "--targetStack", - description = "Name of target stack for corrected tiles", - required = true) - public String targetStack; - - @Parameter( - names = "--matchOwner", - description = "Owner of match collection (default is same as render owner)") - public String matchOwner; - - @Parameter( - names = "--matchCollection", - description = "Name of match collection containing within-layer montage matches", - required = true) - public String matchCollection; - - @Parameter( - names = "--skipMatchCorrection", - description = "Skip transforming match coordinates (default is to transform them)") - public boolean skipMatchCorrection = false; - - @Parameter( - names = "--parameterCsv", - description = "Path to write per-mFOV parameter CSV with stretch estimates and validation results") - public String parameterCsv; - - String getMatchOwner() { - return matchOwner != null ? matchOwner : renderWeb.owner; - } - - String getTargetMatchCollection() { - return targetStack + "_match"; - } + public CreepCorrectionParameters creepCorrection; } public static void main(final String[] args) { @@ -103,87 +70,169 @@ public static void main(final String[] args) { public void runClient(final String[] args) throws Exception { final Parameters parameters = new Parameters(); parameters.parse(args); + parameters.creepCorrection.validate(); LOG.info("runClient: entry, parameters={}", parameters); - final CreepCorrectionSparkClient client = new CreepCorrectionSparkClient(parameters); - client.run(); + final CreepCorrectionSparkClient client = new CreepCorrectionSparkClient(); + client.createContextAndRun(parameters); } }; clientRunner.run(); } - private final Parameters parameters; + public CreepCorrectionSparkClient() { + } - public CreepCorrectionSparkClient(final Parameters parameters) { - this.parameters = parameters; + /** + * Create a spark context and run the client with the specified parameters. + */ + public void createContextAndRun(final Parameters clientParameters) + throws IOException { + final SparkConf conf = new SparkConf().setAppName(getClass().getSimpleName()); + try (final JavaSparkContext sparkContext = new JavaSparkContext(conf)) { + + LOG.info("createContextAndRun: appId is {}", sparkContext.getConf().getAppId()); + + for (final StackWithZValues stackWithAllZ : clientParameters.multiProject.buildListOfStackWithAllZ()) { + correctCreep(sparkContext, + clientParameters.multiProject.getBaseDataUrl(), + stackWithAllZ, + clientParameters.creepCorrection, + clientParameters.multiProject.getMatchCollectionIdForStack(stackWithAllZ.getStackId())); + } + } } - public void run() throws IOException { + /** Validates the specified pipeline parameters are sufficient. */ + @Override + public void validatePipelineParameters(final AlignmentPipelineParameters pipelineParameters) + throws IllegalArgumentException { + final CreepCorrectionParameters creepCorrection = pipelineParameters.getCreepCorrection(); + AlignmentPipelineParameters.validateRequiredElementExists("creepCorrection", + creepCorrection); + creepCorrection.validate(); + } - final SparkConf conf = new SparkConf().setAppName("CreepCorrectionSparkClient"); + /** Run the client as part of an alignment pipeline. */ + @Override + public void runPipelineStep(final JavaSparkContext sparkContext, + final AlignmentPipelineParameters pipelineParameters) + throws IllegalArgumentException, IOException { + + final MultiProjectParameters multiProject = + pipelineParameters.getMultiProject(pipelineParameters.getRawNamingGroup()); + + for (final StackWithZValues stackWithAllZ : multiProject.buildListOfStackWithAllZ()) { + correctCreep(sparkContext, + multiProject.getBaseDataUrl(), + stackWithAllZ, + pipelineParameters.getCreepCorrection(), + multiProject.getMatchCollectionIdForStack(stackWithAllZ.getStackId())); + } + } - try (final JavaSparkContext sparkContext = new JavaSparkContext(conf)) { + @Override + public AlignmentPipelineStepId getDefaultStepId() { + return AlignmentPipelineStepId.CORRECT_CREEP; + } - final String sparkAppId = sparkContext.getConf().getAppId(); - final String executorsJson = LogUtilities.getExecutorsApiJson(sparkAppId); - LOG.info("run: appId is {}, executors data is {}", sparkAppId, executorsJson); + public void correctCreep(final JavaSparkContext sparkContext, + final String baseDataUrl, + final StackWithZValues stackWithAllZ, + final CreepCorrectionParameters creepCorrection, + final MatchCollectionId matchCollectionId) throws IOException { - final RenderDataClient sourceDataClient = parameters.renderWeb.getDataClient(); + final StackId sourceStackId = stackWithAllZ.getStackId(); + final String sourceStackDevString = sourceStackId.toDevString(); - final List zValues = sourceDataClient.getStackZValues(parameters.stack, - parameters.layerRange.minZ, - parameters.layerRange.maxZ); + LOG.info("correctCreep: entry, {} with z {} to {} and creepCorrection {}", + sourceStackDevString, stackWithAllZ.getFirstZ(), stackWithAllZ.getLastZ(), creepCorrection); - if (zValues.isEmpty()) { - throw new IllegalArgumentException("source stack does not contain any matching z values"); - } + final String sourceStack = sourceStackId.getStack(); + final String targetStack = creepCorrection.getTargetStack(sourceStack); + final MatchCollectionId targetMatchCollectionId = new MatchCollectionId(matchCollectionId.getOwner(), + targetStack + "_match"); - // set up target stack on the driver - final StackMetaData sourceStackMetaData = sourceDataClient.getStackMetaData(parameters.stack); - sourceDataClient.setupDerivedStack(sourceStackMetaData, parameters.targetStack); + final RenderDataClient sourceDataClient = new RenderDataClient(baseDataUrl, + sourceStackId.getOwner(), + sourceStackId.getProject()); - // Phase 1: process tiles and collect corrections - LOG.info("run: Phase 1 - distributing {} z values for tile correction", zValues.size()); + final List correctedZValues = creepCorrection.buildCorrectedZValues(stackWithAllZ.getzValues()); + if (correctedZValues.isEmpty()) { + throw new IllegalArgumentException("source stack does not contain any matching z values to be corrected"); + } - final JavaRDD rddZValues = sparkContext.parallelize(zValues); + // set up target stack on the driver + final StackMetaData sourceStackMetaData = sourceDataClient.getStackMetaData(sourceStack); + sourceDataClient.setupDerivedStack(sourceStackMetaData, targetStack); + + // Phase 1: process tiles and collect corrections + LOG.info("correctCreep: {}, phase 1 - distributing {} z values for tile correction", + sourceStackDevString, correctedZValues.size()); + + final JavaRDD rddCorrectedZValues = sparkContext.parallelize(correctedZValues); + + final JavaRDD> rddResults = + rddCorrectedZValues.map(z -> processSingleLayer( + baseDataUrl, + sourceStackId, + creepCorrection.targetStackSuffix, + z, + matchCollectionId)); + final List> resultList = rddResults.collect(); + + // collect all corrections on the driver + final Map> allResults = new HashMap<>(); + for (int i = 0; i < correctedZValues.size(); i++) { + allResults.put(String.valueOf(correctedZValues.get(i).doubleValue()), resultList.get(i)); + } - final JavaRDD> rddResults = rddZValues.map(this::processSingleLayer); - final List> resultList = rddResults.collect(); + final List uncorrectedZValues = creepCorrection.buildUncorrectedZValues(stackWithAllZ.getzValues()); + if (! uncorrectedZValues.isEmpty()) { - // collect all corrections on the driver - final Map> allResults = new HashMap<>(); - for (int i = 0; i < zValues.size(); i++) { - allResults.put(String.valueOf(zValues.get(i).doubleValue()), resultList.get(i)); - } + LOG.info("correctCreep: {}, phase 1 - distributing {} z values for simple copy", + sourceStackDevString, uncorrectedZValues.size()); - LOG.info("run: Phase 1 complete - processed {} z-layers", zValues.size()); + final JavaRDD rddUncorrectedZValues = sparkContext.parallelize(uncorrectedZValues); + final JavaRDD rddCopyResults = + rddUncorrectedZValues.map(z -> copySingleLayer( + baseDataUrl, + sourceStackId, + creepCorrection.targetStackSuffix, + z)); + rddCopyResults.collect(); + } - // write or log results on the driver - reportResults(allResults); + LOG.info("correctCreep: {}, phase 1 complete - processed {} z-layers", + sourceStackDevString, correctedZValues.size()); - // complete target stack on the driver - sourceDataClient.setStackState(parameters.targetStack, StackMetaData.StackState.COMPLETE); + // write or log results on the driver + reportResults(creepCorrection, sourceStackId, allResults); - // Phase 2: transform matches - if (!parameters.skipMatchCorrection) { - transformMatches(sparkContext, allResults); - } else { - LOG.info("run: skipping match correction (--skipMatchCorrection)"); - } + // complete target stack on the driver + sourceDataClient.setStackState(targetStack, StackMetaData.StackState.COMPLETE); + + // Phase 2: transform matches + if (! creepCorrection.skipMatchCorrection) { + transformMatches(sparkContext, baseDataUrl, matchCollectionId, targetMatchCollectionId, allResults); + } else { + LOG.info("correctCreep: skipping match correction"); } - LOG.info("run: exit"); + LOG.info("correctCreep: exit"); } private void transformMatches(final JavaSparkContext sparkContext, + final String baseDataUrl, + final MatchCollectionId matchCollectionId, + final MatchCollectionId targetMatchCollectionId, final Map> allResults) throws IOException { - final RenderDataClient driverMatchClient = new RenderDataClient( - parameters.renderWeb.baseDataUrl, - parameters.getMatchOwner(), - parameters.matchCollection); + final RenderDataClient driverMatchClient = new RenderDataClient(baseDataUrl, + matchCollectionId.getOwner(), + matchCollectionId.getName()); final List pGroupIds = driverMatchClient.getMatchPGroupIds(); @@ -192,47 +241,81 @@ private void transformMatches(final JavaSparkContext sparkContext, return; } - LOG.info("run: Phase 2 - distributing {} match groups for coordinate transformation", pGroupIds.size()); + LOG.info("transformMatches: Phase 2 - distributing {} match groups for coordinate transformation", + pGroupIds.size()); final Broadcast>> broadcastResults = sparkContext.broadcast(allResults); final JavaRDD rddGroupIds = sparkContext.parallelize(pGroupIds); - rddGroupIds.foreach(groupId -> transformMatchesForSingleGroup(groupId, broadcastResults.value())); + rddGroupIds.foreach(groupId -> + transformMatchesForSingleGroup(baseDataUrl, + matchCollectionId, + targetMatchCollectionId, + groupId, + broadcastResults.value())); - LOG.info("run: Phase 2 complete - transformed matches for {} groups", pGroupIds.size()); + LOG.info("transformMatches: Phase 2 complete - transformed matches for {} groups", + pGroupIds.size()); } - private List processSingleLayer(final Double z) throws IOException { - LogUtilities.setupExecutorLog4j("z " + z); + private List processSingleLayer(final String baseDataUrl, + final StackId stackId, + final String targetStackSuffix, + final Double z, + final MatchCollectionId matchCollectionId) throws IOException { + + LogUtilities.setupExecutorLog4j(stackId.toDevString() + "::z" + z); - final RenderDataClient executorRenderClient = parameters.renderWeb.getDataClient(); - final RenderDataClient executorMatchClient = new RenderDataClient( - parameters.renderWeb.baseDataUrl, - parameters.getMatchOwner(), - parameters.matchCollection); + final RenderDataClient executorRenderClient = new RenderDataClient(baseDataUrl, + stackId.getOwner(), + stackId.getProject()); + final RenderDataClient executorMatchClient = new RenderDataClient(baseDataUrl, + matchCollectionId.getOwner(), + matchCollectionId.getName()); final CreepCorrectionClient correctionClient = new CreepCorrectionClient(); + final String targetStack = stackId.getStack() + targetStackSuffix; return correctionClient.processZLayer(z, executorRenderClient, executorMatchClient, - parameters.stack, - parameters.targetStack); + stackId.getStack(), + targetStack); + } + + private Double copySingleLayer(final String baseDataUrl, + final StackId stackId, + final String targetStackSuffix, + final Double z) throws IOException { + + LogUtilities.setupExecutorLog4j(stackId.toDevString() + "::z" + z); + + final RenderDataClient executorRenderClient = new RenderDataClient(baseDataUrl, + stackId.getOwner(), + stackId.getProject()); + final String targetStack = stackId.getStack() + targetStackSuffix; + + final ResolvedTileSpecCollection resolvedTiles = executorRenderClient.getResolvedTiles(stackId.getStack(), z); + executorRenderClient.saveResolvedTiles(resolvedTiles, targetStack, z); + + return z; } - private void transformMatchesForSingleGroup(final String groupId, + private void transformMatchesForSingleGroup(final String baseDataUrl, + final MatchCollectionId matchCollectionId, + final MatchCollectionId targetMatchCollectionId, + final String groupId, final Map> allResults) throws IOException { - LogUtilities.setupExecutorLog4j("matchTransform " + groupId); - final RenderDataClient sourceMatchClient = new RenderDataClient( - parameters.renderWeb.baseDataUrl, - parameters.getMatchOwner(), - parameters.matchCollection); - final RenderDataClient targetMatchClient = new RenderDataClient( - parameters.renderWeb.baseDataUrl, - parameters.getMatchOwner(), - parameters.getTargetMatchCollection()); + LogUtilities.setupExecutorLog4j(matchCollectionId.toDevString() + "::" + groupId); + + final RenderDataClient sourceMatchClient = new RenderDataClient(baseDataUrl, + matchCollectionId.getOwner(), + matchCollectionId.getName()); + final RenderDataClient targetMatchClient = new RenderDataClient(baseDataUrl, + targetMatchCollectionId.getOwner(), + targetMatchCollectionId.getName()); final CreepCorrectionClient correctionClient = new CreepCorrectionClient(); correctionClient.transformMatchesForGroup(groupId, @@ -241,13 +324,20 @@ private void transformMatchesForSingleGroup(final String groupId, targetMatchClient); } - private void reportResults(final Map> allResults) { + private void reportResults(final CreepCorrectionParameters creepCorrection, + final StackId sourceStackId, + final Map> allResults) { + final List sortedScans = new ArrayList<>(allResults.keySet()); sortedScans.sort(Comparator.comparingDouble(Double::parseDouble)); boolean written = false; - if (parameters.parameterCsv != null) { - try (final PrintWriter writer = new PrintWriter(parameters.parameterCsv)) { + if (creepCorrection.parameterCsvDir != null) { + + final File csvFile = new File(creepCorrection.parameterCsvDir, + "creep-correction-param." + sourceStackId.getStack() + ".csv"); + + try (final PrintWriter writer = new PrintWriter(csvFile)) { writer.println("scan," + MfovResult.CSV_HEADER); for (final String scan : sortedScans) { for (final MfovResult result : allResults.get(scan)) { @@ -255,9 +345,9 @@ private void reportResults(final Map> allResults) { } } written = true; - LOG.info("reportResults: wrote creep correction parameters to {}", parameters.parameterCsv); + LOG.info("reportResults: wrote creep correction parameters to {}", csvFile); } catch (final Exception e) { - LOG.error("reportResults: failed to write CSV to {}, logging instead", parameters.parameterCsv, e); + LOG.error("reportResults: failed to write CSV to {}, logging instead", csvFile, e); } } diff --git a/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/pipeline/AlignmentPipelineParameters.java b/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/pipeline/AlignmentPipelineParameters.java index 3e0951e88..04a1662d9 100644 --- a/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/pipeline/AlignmentPipelineParameters.java +++ b/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/pipeline/AlignmentPipelineParameters.java @@ -16,6 +16,7 @@ import org.janelia.alignment.util.UrlResourceUtil; import org.janelia.render.client.newsolver.setup.AffineBlockSolverSetup; import org.janelia.render.client.newsolver.setup.IntensityCorrectionSetup; +import org.janelia.render.client.parameter.CreepCorrectionParameters; import org.janelia.render.client.parameter.MFOVAsTileParameters; import org.janelia.render.client.parameter.MFOVMontageMatchPatchParameters; import org.janelia.render.client.parameter.MaskHackParameters; @@ -48,6 +49,7 @@ public class AlignmentPipelineParameters private final UnconnectedCrossMFOVParameters unconnectedCrossMfov; private final TileClusterParameters tileCluster; private final MatchCopyParameters matchCopy; + private final CreepCorrectionParameters creepCorrection; private final AffineBlockSolverSetup affineBlockSolverSetup; private final IntensityCorrectionSetup intensityCorrectionSetup; private final ZSpacingParameters zSpacing; @@ -73,6 +75,7 @@ public AlignmentPipelineParameters() { null, null, null, + null, null); } @@ -85,6 +88,7 @@ public AlignmentPipelineParameters(final MultiProjectParameters multiProject, final UnconnectedCrossMFOVParameters unconnectedCrossMfov, final TileClusterParameters tileCluster, final MatchCopyParameters matchCopy, + final CreepCorrectionParameters creepCorrection, final AffineBlockSolverSetup affineBlockSolverSetup, final IntensityCorrectionSetup intensityCorrectionSetup, final ZSpacingParameters zSpacing, @@ -101,6 +105,7 @@ public AlignmentPipelineParameters(final MultiProjectParameters multiProject, this.unconnectedCrossMfov = unconnectedCrossMfov; this.tileCluster = tileCluster; this.matchCopy = matchCopy; + this.creepCorrection = creepCorrection; this.affineBlockSolverSetup = affineBlockSolverSetup; this.intensityCorrectionSetup = intensityCorrectionSetup; this.zSpacing = zSpacing; @@ -156,6 +161,10 @@ public MatchCopyParameters getMatchCopy() { return matchCopy; } + public CreepCorrectionParameters getCreepCorrection() { + return creepCorrection; + } + public String getMatchCopyToCollectionSuffix() { return matchCopy == null ? "" : matchCopy.toCollectionSuffix; } diff --git a/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/pipeline/AlignmentPipelineStepId.java b/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/pipeline/AlignmentPipelineStepId.java index 0af0b09b8..4f1627590 100644 --- a/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/pipeline/AlignmentPipelineStepId.java +++ b/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/pipeline/AlignmentPipelineStepId.java @@ -8,6 +8,7 @@ import org.janelia.render.client.spark.match.ClusterCountClient; import org.janelia.render.client.spark.match.CopyMatchClient; import org.janelia.render.client.spark.match.MultiStagePointMatchClient; +import org.janelia.render.client.spark.multisem.CreepCorrectionSparkClient; import org.janelia.render.client.spark.multisem.MFOVASTileClient; import org.janelia.render.client.spark.multisem.MFOVMontageMatchPatchClient; import org.janelia.render.client.spark.multisem.UnconnectedCrossMFOVClient; @@ -29,6 +30,7 @@ public enum AlignmentPipelineStepId { FIND_UNCONNECTED_CROSS_MFOVS(UnconnectedCrossMFOVClient::new), FIND_UNCONNECTED_TILES_AND_EDGES(ClusterCountClient::new), FILTER_MATCHES(CopyMatchClient::new), + CORRECT_CREEP(CreepCorrectionSparkClient::new), ALIGN_TILES(DistributedAffineBlockSolverClient::new), CORRECT_Z_POSITIONS(ZPositionCorrectionClient::new), CORRECT_INTENSITY(DistributedIntensityCorrectionBlockSolverClient::new), From 96ac9121f985eab7f1bb801ffff0cfc0f94b09e3 Mon Sep 17 00:00:00 2001 From: Eric Trautman Date: Mon, 13 Apr 2026 10:24:12 -0400 Subject: [PATCH 13/13] fix z value list bugs --- .../render/client/parameter/CreepCorrectionParameters.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/render-ws-java-client/src/main/java/org/janelia/render/client/parameter/CreepCorrectionParameters.java b/render-ws-java-client/src/main/java/org/janelia/render/client/parameter/CreepCorrectionParameters.java index 10389cdf7..6198b14c5 100644 --- a/render-ws-java-client/src/main/java/org/janelia/render/client/parameter/CreepCorrectionParameters.java +++ b/render-ws-java-client/src/main/java/org/janelia/render/client/parameter/CreepCorrectionParameters.java @@ -66,7 +66,7 @@ public String getTargetStack(final String sourceStack) { public List buildCorrectedZValues(final List allZValues) { - final List correctedZValues = new ArrayList<>(allZValues); + final List correctedZValues = new ArrayList<>(); for (final Double z : allZValues) { if ( (minZ == null || (z >= minZ)) && (maxZ == null || (z <= maxZ)) ) { correctedZValues.add(z); @@ -76,7 +76,7 @@ public List buildCorrectedZValues(final List allZValues) { } public List buildUncorrectedZValues(final List allZValues) { - final List uncorrectedZValues = new ArrayList<>(allZValues); + final List uncorrectedZValues = new ArrayList<>(); for (final Double z : allZValues) { if ( (minZ != null && (z < minZ)) || (maxZ != null && (z > maxZ)) ) { uncorrectedZValues.add(z);