diff --git a/render-app/src/main/java/org/janelia/alignment/transform/StageCreepCorrectionTransform.java b/render-app/src/main/java/org/janelia/alignment/transform/StageCreepCorrectionTransform.java new file mode 100644 index 000000000..003d90fc0 --- /dev/null +++ b/render-app/src/main/java/org/janelia/alignment/transform/StageCreepCorrectionTransform.java @@ -0,0 +1,78 @@ +package org.janelia.alignment.transform; + +import mpicbg.trakem2.transform.CoordinateTransform; + +/** + * Forward transform for piezo stage creep correction in multi-SEM imaging. + * + *

The backward map (used by the Python prototype via cv2.remap) is:

+ *
y_source = y_world + a * exp(-y_world / b) + c * exp(-y_world / d)
+ * + *

This class implements the exact forward (source → world) transform by numerically + * inverting the backward map using Newton's method: given {@code y_s}, find {@code y_w} such that + * {@code y_w + a * exp(-y_w / b) + c * exp(-y_w / d) = y_s}.

+ * + *

Coefficients: a, b, c, d (same as {@link SEMDistortionTransformA}). + * Data string format: {@code "a,b,c,d,dimension"}.

+ */ +public class StageCreepCorrectionTransform + extends MultiParameterSingleDimensionTransform { + + private static final int MAX_ITERATIONS = 10; + private static final double TOLERANCE = 1e-9; + + public StageCreepCorrectionTransform() { + this(0, 0, 0, 0, 0); + } + + public StageCreepCorrectionTransform(final double a, + final double b, + final double c, + final double d, + final int dimension) { + super(new double[] {a, b, c, d}, dimension); + } + + @Override + public int getNumberOfCoefficients() { + return 4; + } + + /** + * Forward transform (source → world): given y_s, computes y_w by solving + * {@code y_w + a * exp(-y_w / b) + c * exp(-y_w / d) = y_s} using Newton's method. + */ + @Override + public void applyInPlace(final double[] location) { + final double a = coefficients[0]; + final double b = coefficients[1]; + final double c = coefficients[2]; + final double d = coefficients[3]; + final double y_s = location[dimension]; + + // solve g(y_w) = y_w + a*exp(-y_w/b) + c*exp(-y_w/d) - y_s = 0 + // g'(y_w) = 1 - (a/b)*exp(-y_w/b) - (c/d)*exp(-y_w/d) + double y_w = y_s; // initial guess + for (int i = 0; i < MAX_ITERATIONS; i++) { + final double expB = Math.exp(-y_w / b); + final double expD = Math.exp(-y_w / d); + final double g = y_w + a * expB + c * expD - y_s; + if (Math.abs(g) < TOLERANCE) { + break; + } + final double gPrime = 1.0 - (a / b) * expB - (c / d) * expD; + y_w -= g / gPrime; + } + + location[dimension] = y_w; + } + + @Override + public CoordinateTransform copy() { + return new StageCreepCorrectionTransform(coefficients[0], + coefficients[1], + coefficients[2], + coefficients[3], + dimension); + } +} diff --git a/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java b/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java new file mode 100644 index 000000000..7c63ee335 --- /dev/null +++ b/render-ws-java-client/src/main/java/org/janelia/render/client/multisem/CreepCorrectionClient.java @@ -0,0 +1,582 @@ +package org.janelia.render.client.multisem; + +import static org.janelia.alignment.spec.ResolvedTileSpecCollection.TransformApplicationMethod.INSERT_BEFORE_LAST; + +import java.io.IOException; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +import org.janelia.alignment.match.CanvasMatchResult; +import org.janelia.alignment.match.CanvasMatches; +import org.janelia.alignment.match.Matches; +import org.janelia.alignment.multisem.MultiSemUtilities; +import org.janelia.alignment.spec.LeafTransformSpec; +import org.janelia.alignment.spec.LayoutData; +import org.janelia.alignment.spec.ResolvedTileSpecCollection; +import org.janelia.alignment.spec.TileSpec; +import org.janelia.alignment.spec.TransformSpec; +import org.janelia.alignment.transform.StageCreepCorrectionTransform; +import org.janelia.render.client.RenderDataClient; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import mpicbg.models.AffineModel2D; +import mpicbg.models.CoordinateTransform; +import mpicbg.models.NotEnoughDataPointsException; +import mpicbg.models.PointMatch; + +/** + * Estimates and applies piezo creep correction for multi-SEM sFOV tiles within each mFOV. + * + *

The correction is estimated by fitting pairwise affine transforms between geometrically + * adjacent sFOV pairs and extracting the y-stretch factor. A double-exponential correction + * (using {@link org.janelia.alignment.transform.SEMDistortionTransformA}) is then derived + * and inserted into each tile's transform list.

+ * + * @author Michael Innerberger + */ +public class CreepCorrectionClient { + + // physical constants for the double-exponential model + private static final double TAU_0 = 0.42; // seconds + private static final double TAU_1 = 4.0; // seconds + private static final double PIXEL_DWELL = 800e-9; // seconds + private static final double LINE_WIDTH = 2000; // pixels + private static final double L0 = TAU_0 / (PIXEL_DWELL * LINE_WIDTH); // 262500 lines + private static final double L1 = TAU_1 / (PIXEL_DWELL * LINE_WIDTH); // 2500000 lines + + // line numbers used for stretch-to-amplitude conversion + private static final int TOP_LINE = 50; + private static final int BOTTOM_LINE = 1700; + + // validation thresholds + private static final double MIN_STRETCH = 0.9; + private static final double MAX_STRETCH = 1.1; + private static final int MIN_VALID_STRETCHES = 10; + private static final double MAX_STRETCH_STDDEV = 0.02; + private static final double MIN_MEDIAN_STRETCH = 0.995; + + // RANSAC parameters for pairwise affine estimation + private static final int RANSAC_ITERATIONS = 1000; + private static final double RANSAC_MAX_EPSILON = 10.0; + private static final double RANSAC_MIN_INLIER_RATIO = 0.1; + private static final int RANSAC_MIN_NUM_INLIERS = 7; + private static final double RANSAC_MAX_TRUST = 3.0; + + // minimum offset for geometric neighbor detection (in stage coordinate units) + private static final double MIN_NEIGHBOR_OFFSET = 100.0; + + /** + * Processes all mFOVs for a given z-layer: loads tiles and matches, estimates creep correction, + * applies it where valid, and saves the results to the target stack. + * + * @return per-mFOV results (one per mFOV, including skipped ones) + */ + public List processZLayer(final double z, + final RenderDataClient renderDataClient, + final RenderDataClient matchDataClient, + final String stack, + final String targetStack) + throws IOException { + + LOG.info("processZLayer: entry, z={}", z); + + final ResolvedTileSpecCollection resolvedTiles = renderDataClient.getResolvedTiles(stack, z); + + // group tiles by mFOV + final Map> mfovToTiles = new HashMap<>(); + for (final TileSpec tileSpec : resolvedTiles.getTileSpecs()) { + final String mfov = MultiSemUtilities.getMagcMfovForTileId(tileSpec.getTileId()); + mfovToTiles.computeIfAbsent(mfov, k -> new ArrayList<>()).add(tileSpec); + } + + // load all within-group matches and index by tile pair + final String groupId = String.valueOf(z); + final List allMatches = matchDataClient.getMatchesWithinGroup(groupId, false); + final Map pairKeyToMatches = new HashMap<>(); + for (final CanvasMatches match : allMatches) { + final String pMFOV = MultiSemUtilities.getMagcMfovForTileId(match.getpId()); + final String qMFOV = MultiSemUtilities.getMagcMfovForTileId(match.getqId()); + if (pMFOV.equals(qMFOV)) { + pairKeyToMatches.put(pairKey(match.getpId(), match.getqId()), match); + } + } + + // process each mFOV independently + final List mfovResults = new ArrayList<>(); + int correctedMFOVCount = 0; + int skippedMFOVCount = 0; + for (final Map.Entry> entry : mfovToTiles.entrySet()) { + final String mfov = entry.getKey(); + final List mfovTiles = entry.getValue(); + + LOG.info("processZLayer: processing mFOV {}", mfov); + final MfovResult result = processMFOV(mfov, mfovTiles, pairKeyToMatches, resolvedTiles); + mfovResults.add(result); + + if (result.isValid) { + correctedMFOVCount++; + LOG.info("processZLayer: successfully corrected mFOV {}", mfov); + } else { + skippedMFOVCount++; + LOG.info("processZLayer: failed correcting mFOV {}: {}", mfov, result.diagnosticMessage); + } + } + + // save all tiles (both corrected and uncorrected) + renderDataClient.saveResolvedTiles(resolvedTiles, targetStack, z); + + LOG.info("processZLayer: exit, z={}, correctedMFOVs={}, skippedMFOVs={}, totalTiles={}", + z, correctedMFOVCount, skippedMFOVCount, resolvedTiles.getTileCount()); + + return mfovResults; + } + + /** + * Processes a single mFOV: finds neighbor pairs, estimates stretch, validates, and applies correction. + */ + MfovResult processMFOV(final String mfov, + final List mfovTiles, + final Map pairKeyToMatches, + final ResolvedTileSpecCollection resolvedTiles) { + + // find geometric neighbor pairs + final List neighborPairs = findGeometricNeighborPairs(mfovTiles); + + if (neighborPairs.isEmpty()) { + return MfovResult.invalid(mfov, 0, "no geometric neighbor pairs found"); + } + + // estimate y-stretch for each pair that has matches + final List stretches = new ArrayList<>(); + int pairsWithoutMatches = 0; + for (final TilePair pair : neighborPairs) { + final String key = pairKey(pair.pTileId, pair.qTileId); + final CanvasMatches matches = pairKeyToMatches.get(key); + + if (matches == null) { + pairsWithoutMatches++; + continue; + } + + final double yStretch = estimateYStretchForPair( + CanvasMatchResult.convertMatchesToPointMatchList(matches.getMatches())); + stretches.add(yStretch); + } + + if (pairsWithoutMatches > 0) { + LOG.info("processMFOV: {} of {} neighbor pairs in mFOV {} had no matches", + pairsWithoutMatches, neighborPairs.size(), mfov); + } + + // validate and build result + final MfovResult result = validateStretches(stretches, mfov); + + if (result.isValid) { + final Set mfovTileIds = mfovTiles.stream() + .map(TileSpec::getTileId) + .collect(Collectors.toSet()); + applyCorrectionToMFOV(resolvedTiles, mfovTileIds, result.correctionSpec); + } + + return result; + } + + /** + * Finds geometric neighbor pairs (lower-right and lower-left) using stage positions, + * replicating the Python prototype's neighbor-finding logic. + */ + List findGeometricNeighborPairs(final List mfovTiles) { + final List pairs = new ArrayList<>(); + + for (final TileSpec targetTile : mfovTiles) { + final StageCoordinates target = getStageCoordinates(targetTile); + if (target == null) continue; + + // find closest lower-right neighbor + findClosestNeighbor(targetTile.getTileId(), target, mfovTiles, true) + .ifPresent(pairs::add); + + // find closest lower-left neighbor + findClosestNeighbor(targetTile.getTileId(), target, mfovTiles, false) + .ifPresent(pairs::add); + } + + return pairs; + } + + private static StageCoordinates getStageCoordinates(final TileSpec targetTile) { + final LayoutData targetLayout = targetTile.getLayout(); + if (targetLayout == null || targetLayout.getStageX() == null || targetLayout.getStageY() == null) { + return null; + } + final double targetX = targetLayout.getStageX(); + final double targetY = targetLayout.getStageY(); + return new StageCoordinates(targetX, targetY); + } + + private Optional findClosestNeighbor(final String targetTileId, + final StageCoordinates target, + final List candidates, + final boolean lowerRight) { + double minDist = Double.MAX_VALUE; + String closestId = null; + + for (final TileSpec candidate : candidates) { + if (candidate.getTileId().equals(targetTileId)) { + continue; + } + final LayoutData layout = candidate.getLayout(); + final StageCoordinates stage = getStageCoordinates(candidate); + + if (layout == null || stage == null) { + continue; + } + + final boolean xCondition = lowerRight + ? (stage.x > target.x + MIN_NEIGHBOR_OFFSET) + : (stage.x < target.x - MIN_NEIGHBOR_OFFSET); + + if (xCondition && (stage.y > target.y + MIN_NEIGHBOR_OFFSET)) { + final double dist = Math.sqrt(Math.pow(stage.x - target.x, 2) + Math.pow(stage.y - target.y, 2)); + if (dist < minDist) { + minDist = dist; + closestId = candidate.getTileId(); + } + } + } + + if (closestId != null) { + // normalize pair ordering so pairKey lookups work regardless of match storage order + final String pId = targetTileId.compareTo(closestId) < 0 ? targetTileId : closestId; + final String qId = targetTileId.compareTo(closestId) < 0 ? closestId : targetTileId; + return java.util.Optional.of(new TilePair(pId, qId)); + } + return java.util.Optional.empty(); + } + + /** + * Estimates the y-stretch factor from a set of point matches by fitting an affine model with RANSAC. + * + * @return the y-stretch (m11 element of the affine), or {@link Double#NaN} on failure + */ + double estimateYStretchForPair(final List candidates) { + if (candidates.size() < RANSAC_MIN_NUM_INLIERS) { + return Double.NaN; + } + + final AffineModel2D model = new AffineModel2D(); + final List inliers = new ArrayList<>(); + + try { + model.filterRansac(candidates, + inliers, + RANSAC_ITERATIONS, + RANSAC_MAX_EPSILON, + RANSAC_MIN_INLIER_RATIO, + RANSAC_MIN_NUM_INLIERS, + RANSAC_MAX_TRUST); + } catch (final NotEnoughDataPointsException e) { + return Double.NaN; + } + + if (inliers.isEmpty()) { + return Double.NaN; + } + + final double[] affineData = new double[6]; + model.toArray(affineData); + // affineData layout: [m00, m10, m01, m11, m02, m12] + return affineData[3]; // m11 = y-stretch + } + + /** + * Computes the creep correction amplitude from the median y-stretch. + * Returns a negative value (same sign as the Python prototype's {@code calculate_a_helper_func}), + * which is used directly as the coefficient in {@link StageCreepCorrectionTransform}'s backward map: + * {@code y_source = y_world + a * exp(-y_world / L)}. + */ + static double computeCorrectionAmplitude(final double medianStretch) { + final double b = (-1.0 / L0) * Math.exp(-(double) TOP_LINE / L0) + + (-1.0 / L1) * Math.exp(-(double) TOP_LINE / L1); + final double c = (-1.0 / L0) * Math.exp(-(double) BOTTOM_LINE / L0) + + (-1.0 / L1) * Math.exp(-(double) BOTTOM_LINE / L1); + return (medianStretch - 1.0) / (b - c * medianStretch); + } + + /** + * Validates stretch estimates for a single mFOV and returns the result. + */ + MfovResult validateStretches(final List stretches, final String mfov) { + final int totalPairs = stretches.size(); + + if (totalPairs == 0) { + return MfovResult.invalid(mfov, 0, "no stretch estimates available"); + } + + // filter out NaN and out-of-range values + final List validStretches = new ArrayList<>(); + int nanCount = 0; + int outOfRangeCount = 0; + for (final double s : stretches) { + if (Double.isNaN(s)) { + nanCount++; + } else if (s < MIN_STRETCH || s > MAX_STRETCH) { + outOfRangeCount++; + } else { + validStretches.add(s); + } + } + + if (validStretches.size() < MIN_VALID_STRETCHES) { + return MfovResult.invalid(mfov, totalPairs, + "only " + validStretches.size() + " valid stretches (need " + MIN_VALID_STRETCHES + + ") nanCount=" + nanCount + " outOfRange=" + outOfRangeCount); + } + + // compute median + validStretches.sort(Double::compareTo); + final double medianStretch; + final int n = validStretches.size(); + if (n % 2 == 0) { + medianStretch = (validStretches.get(n / 2 - 1) + validStretches.get(n / 2)) / 2.0; + } else { + medianStretch = validStretches.get(n / 2); + } + + // compute standard deviation + final double mean = validStretches.stream().mapToDouble(Double::doubleValue).average().orElse(0.0); + final double variance = validStretches.stream() + .mapToDouble(s -> Math.pow(s - mean, 2)) + .sum() / validStretches.size(); + final double stddev = Math.sqrt(variance); + + if (stddev > MAX_STRETCH_STDDEV) { + return new MfovResult(mfov, medianStretch, stddev, Double.NaN, + validStretches.size(), totalPairs, null, + "stretch stddev exceeds threshold " + MAX_STRETCH_STDDEV); + } + + if (medianStretch > MIN_MEDIAN_STRETCH) { + return new MfovResult(mfov, medianStretch, stddev, Double.NaN, + validStretches.size(), totalPairs, null, + "median stretch " + medianStretch + " above threshold " + MIN_MEDIAN_STRETCH); + } + + // compute amplitude + final double amplitude = computeCorrectionAmplitude(medianStretch); + if (!Double.isFinite(amplitude)) { + return new MfovResult(mfov, medianStretch, stddev, amplitude, + validStretches.size(), totalPairs, null, + "computed amplitude is not finite"); + } + + return new MfovResult(mfov, medianStretch, stddev, amplitude, + validStretches.size(), totalPairs, + buildCorrectionTransformSpec(amplitude), "OK"); + } + + /** + * Builds a {@link LeafTransformSpec} for the creep correction using {@code StageCreepCorrectionTransform}. + * The amplitude is negative (same as the Python prototype), defining the backward map + * {@code y_s = y_w + a*exp(-y_w/L0) + a*exp(-y_w/L1)} where negative {@code a} reads from above, + * compressing the top of the image to correct for creep stretch. + */ + static TransformSpec buildCorrectionTransformSpec(final double amplitude) { + final String dataString = amplitude + "," + L0 + "," + amplitude + "," + L1 + ",1"; + return new LeafTransformSpec("org.janelia.alignment.transform.StageCreepCorrectionTransform", dataString); + } + + /** + * Applies the correction transform to all tiles of an mFOV by inserting it before the last + * (alignment) transform, placing it after the existing scan correction. + */ + private void applyCorrectionToMFOV(final ResolvedTileSpecCollection tiles, + final Set mfovTileIds, + final TransformSpec correctionSpec) { + for (final String tileId : mfovTileIds) { + tiles.addTransformSpecToTile(tileId, correctionSpec, INSERT_BEFORE_LAST); + } + } + + /** + * Creates a canonical pair key for match lookups. CanvasMatches normalizes p < q, + * so we ensure the same ordering. + */ + private static String pairKey(final String id1, final String id2) { + return id1.compareTo(id2) < 0 ? id1 + "::" + id2 : id2 + "::" + id1; + } + + /** + * Transforms all matches for a given group (z-layer) using the collected creep corrections. + * Handles both within-group and outside-group matches. + */ + public void transformMatchesForGroup(final String groupId, + final Map> allResults, + final RenderDataClient sourceMatchClient, + final RenderDataClient targetMatchClient) + throws IOException { + + LOG.info("transformMatchesForGroup: entry, groupId={}", groupId); + + // get within-group and outside-group matches + final List withinMatches = sourceMatchClient.getMatchesWithinGroup(groupId, false); + final List outsideMatches = sourceMatchClient.getMatchesOutsideGroup(groupId, false); + + final List allMatches = new ArrayList<>(withinMatches); + allMatches.addAll(outsideMatches); + + if (allMatches.isEmpty()) { + LOG.info("transformMatchesForGroup: no matches found for groupId={}", groupId); + return; + } + + // transform and collect matches + final List transformedMatches = new ArrayList<>(); + for (final CanvasMatches cm : allMatches) { + transformedMatches.add(transformCanvasMatches(cm, allResults)); + } + + // save to target match collection + targetMatchClient.saveMatches(transformedMatches); + + LOG.info("transformMatchesForGroup: exit, groupId={}, transformed {} matches", + groupId, transformedMatches.size()); + } + + /** + * Transforms a single CanvasMatches by applying creep corrections to both p and q coordinates. + */ + static CanvasMatches transformCanvasMatches(final CanvasMatches cm, + final Map> allResults) { + final Matches matches = cm.getMatches(); + final double[][] ps = matches.getPs(); + final double[][] qs = matches.getQs(); + final double[] ws = matches.getWs(); + final int n = ws.length; + + // deep copy coordinate arrays + final double[][] newPs = new double[][] { Arrays.copyOf(ps[0], n), Arrays.copyOf(ps[1], n) }; + final double[][] newQs = new double[][] { Arrays.copyOf(qs[0], n), Arrays.copyOf(qs[1], n) }; + + // transform p and q coordinates using their respective tile's creep correction + transformMatchCoordinates(newPs, cm.getpId(), cm.getpGroupId(), allResults); + transformMatchCoordinates(newQs, cm.getqId(), cm.getqGroupId(), allResults); + + return new CanvasMatches(cm.getpGroupId(), cm.getpId(), + cm.getqGroupId(), cm.getqId(), + new Matches(newPs, newQs, Arrays.copyOf(ws, n))); + } + + /** + * Transforms match coordinates in place by applying creep correction. + * Match coordinates are in post-lens/pre-alignment space (montage matches are derived before + * alignment), which is the same space where the CC transform operates (it is inserted + * before the alignment transform). So we apply CC directly without touching alignment. + */ + static void transformMatchCoordinates(final double[][] coords, + final String tileId, + final String groupId, + final Map> allResults) { + + // look up CC spec for this tile's mFOV + final String mfov = MultiSemUtilities.getMagcMfovForTileId(tileId); + final List layerResults = allResults.get(groupId); + if (layerResults == null) { + return; + } + + TransformSpec ccSpec = null; + for (final MfovResult r : layerResults) { + if (r.mfov.equals(mfov) && r.correctionSpec != null) { + ccSpec = r.correctionSpec; + break; + } + } + if (ccSpec == null) { + return; + } + + final CoordinateTransform ccTransform = ccSpec.getNewInstance(); + for (int i = 0; i < coords[0].length; i++) { + final double[] point = new double[] { coords[0][i], coords[1][i] }; + ccTransform.applyInPlace(point); + coords[0][i] = point[0]; + coords[1][i] = point[1]; + } + } + + /** Result of processing a single mFOV, including validation outcome and correction spec. */ + public static class MfovResult implements Serializable { + + public static final String CSV_HEADER = "mfov,medianStretch,stddev,amplitude,validPairs,totalPairs,isValid,diagnosticMessage"; + + public final String mfov; + public final double medianStretch; + public final double stddev; + public final double amplitude; + public final int validPairs; + public final int totalPairs; + public final boolean isValid; + public final String diagnosticMessage; + public final TransformSpec correctionSpec; + + MfovResult(final String mfov, + final double medianStretch, + final double stddev, + final double amplitude, + final int validPairs, + final int totalPairs, + final TransformSpec correctionSpec, + final String diagnosticMessage) { + this.mfov = mfov; + this.medianStretch = medianStretch; + this.stddev = stddev; + this.amplitude = amplitude; + this.validPairs = validPairs; + this.totalPairs = totalPairs; + this.isValid = correctionSpec != null; + this.diagnosticMessage = diagnosticMessage; + this.correctionSpec = correctionSpec; + } + + static MfovResult invalid(final String mfov, final int totalPairs, final String reason) { + return new MfovResult(mfov, Double.NaN, Double.NaN, Double.NaN, 0, totalPairs, null, reason); + } + + public String toCsvRow() { + return mfov + "," + medianStretch + "," + stddev + "," + amplitude + "," + + validPairs + "," + totalPairs + "," + isValid + "," + diagnosticMessage; + } + } + + /** A pair of tile IDs representing geometrically adjacent sFOVs. */ + static class TilePair { + final String pTileId; + final String qTileId; + + TilePair(final String pTileId, final String qTileId) { + this.pTileId = pTileId; + this.qTileId = qTileId; + } + } + + static class StageCoordinates { + final double x; + final double y; + + StageCoordinates(final double x, final double y) { + this.x = x; + this.y = y; + } + } + + private static final Logger LOG = LoggerFactory.getLogger(CreepCorrectionClient.class); +} diff --git a/render-ws-java-client/src/main/java/org/janelia/render/client/parameter/CreepCorrectionParameters.java b/render-ws-java-client/src/main/java/org/janelia/render/client/parameter/CreepCorrectionParameters.java new file mode 100644 index 000000000..6198b14c5 --- /dev/null +++ b/render-ws-java-client/src/main/java/org/janelia/render/client/parameter/CreepCorrectionParameters.java @@ -0,0 +1,87 @@ +package org.janelia.render.client.parameter; + +import com.beust.jcommander.Parameter; +import com.beust.jcommander.Parameters; + +import java.io.File; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; + +/** + * Parameters for correcting multi-SEM image creep. + */ +@Parameters +public class CreepCorrectionParameters + implements Serializable { + + @Parameter( + names = "--creepTargetStackSuffix", + description = "Target stack name is the the source stack name with this suffix appended") + public String targetStackSuffix = "_crc"; + + @Parameter( + names = "--creepMinZ", + description = "Minimum Z value for layers that need creep correction") + public Double minZ; + + @Parameter( + names = "--creepMaxZ", + description = "Maximum Z value for layers that need creep correction") + public Double maxZ; + + @Parameter( + names = "--skipMatchCorrection", + description = "Skip transforming match coordinates (default is to transform them)") + public boolean skipMatchCorrection = false; + + @Parameter( + names = "--parameterCsvDir", + description = "Directory where a creep-correction-param..csv file " + + "with per-mFOV parameters, stretch estimates, and validation results " + + "will be written for each source stack") + public String parameterCsvDir; + + public CreepCorrectionParameters() { + } + + public void validate() + throws IllegalArgumentException { + + if ((targetStackSuffix == null) || (targetStackSuffix.trim().isEmpty())) { + throw new IllegalArgumentException("--creepTargetStackSuffix must be defined"); + } + + if (parameterCsvDir != null) { + final File csvDir = new File(parameterCsvDir); + if (! csvDir.isDirectory()) { + throw new IllegalArgumentException("--parameterCsvDir " + parameterCsvDir + " is not a valid directory"); + } + } + } + + public String getTargetStack(final String sourceStack) { + return sourceStack + targetStackSuffix; + } + + + public List buildCorrectedZValues(final List allZValues) { + final List correctedZValues = new ArrayList<>(); + for (final Double z : allZValues) { + if ( (minZ == null || (z >= minZ)) && (maxZ == null || (z <= maxZ)) ) { + correctedZValues.add(z); + } + } + return correctedZValues; + } + + public List buildUncorrectedZValues(final List allZValues) { + final List uncorrectedZValues = new ArrayList<>(); + for (final Double z : allZValues) { + if ( (minZ != null && (z < minZ)) || (maxZ != null && (z > maxZ)) ) { + uncorrectedZValues.add(z); + } + } + return uncorrectedZValues; + } +} \ No newline at end of file diff --git a/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/multisem/CreepCorrectionSparkClient.java b/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/multisem/CreepCorrectionSparkClient.java new file mode 100644 index 000000000..462751b89 --- /dev/null +++ b/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/multisem/CreepCorrectionSparkClient.java @@ -0,0 +1,365 @@ +package org.janelia.render.client.spark.multisem; + +import com.beust.jcommander.ParametersDelegate; + +import java.io.File; +import java.io.IOException; +import java.io.PrintWriter; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.broadcast.Broadcast; +import org.janelia.alignment.match.MatchCollectionId; +import org.janelia.alignment.spec.ResolvedTileSpecCollection; +import org.janelia.alignment.spec.stack.StackId; +import org.janelia.alignment.spec.stack.StackMetaData; +import org.janelia.alignment.spec.stack.StackWithZValues; +import org.janelia.render.client.ClientRunner; +import org.janelia.render.client.RenderDataClient; +import org.janelia.render.client.multisem.CreepCorrectionClient; +import org.janelia.render.client.multisem.CreepCorrectionClient.MfovResult; +import org.janelia.render.client.parameter.CommandLineParameters; +import org.janelia.render.client.parameter.CreepCorrectionParameters; +import org.janelia.render.client.parameter.MultiProjectParameters; +import org.janelia.render.client.spark.LogUtilities; +import org.janelia.render.client.spark.pipeline.AlignmentPipelineParameters; +import org.janelia.render.client.spark.pipeline.AlignmentPipelineStep; +import org.janelia.render.client.spark.pipeline.AlignmentPipelineStepId; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Spark client for applying piezo creep correction to multi-SEM tiles. + * Each z-layer is processed independently on a Spark executor. + * + *

Within each z-layer, all mFOVs are processed sequentially: for each mFOV, the y-stretch + * is estimated from pairwise affine fits of geometrically adjacent sFOV pairs, and a + * double-exponential correction is applied if validation passes. mFOVs that fail validation + * are skipped (uploaded without correction).

+ * + *

After tile correction, existing point matches are transformed to account for the creep + * correction and saved to a new match collection (named {@code targetStack + "_match"}). + * This can be skipped with {@code --skipMatchCorrection}.

+ * + * @see CreepCorrectionClient + * + * @author Michael Innerberger + */ +public class CreepCorrectionSparkClient + implements Serializable, AlignmentPipelineStep { + + public static class Parameters extends CommandLineParameters { + + @ParametersDelegate + public MultiProjectParameters multiProject; + + @ParametersDelegate + public CreepCorrectionParameters creepCorrection; + } + + public static void main(final String[] args) { + final ClientRunner clientRunner = new ClientRunner(args) { + @Override + public void runClient(final String[] args) throws Exception { + final Parameters parameters = new Parameters(); + parameters.parse(args); + parameters.creepCorrection.validate(); + + LOG.info("runClient: entry, parameters={}", parameters); + + final CreepCorrectionSparkClient client = new CreepCorrectionSparkClient(); + client.createContextAndRun(parameters); + } + }; + clientRunner.run(); + } + + public CreepCorrectionSparkClient() { + } + + /** + * Create a spark context and run the client with the specified parameters. + */ + public void createContextAndRun(final Parameters clientParameters) + throws IOException { + final SparkConf conf = new SparkConf().setAppName(getClass().getSimpleName()); + try (final JavaSparkContext sparkContext = new JavaSparkContext(conf)) { + + LOG.info("createContextAndRun: appId is {}", sparkContext.getConf().getAppId()); + + for (final StackWithZValues stackWithAllZ : clientParameters.multiProject.buildListOfStackWithAllZ()) { + correctCreep(sparkContext, + clientParameters.multiProject.getBaseDataUrl(), + stackWithAllZ, + clientParameters.creepCorrection, + clientParameters.multiProject.getMatchCollectionIdForStack(stackWithAllZ.getStackId())); + } + } + } + + /** Validates the specified pipeline parameters are sufficient. */ + @Override + public void validatePipelineParameters(final AlignmentPipelineParameters pipelineParameters) + throws IllegalArgumentException { + final CreepCorrectionParameters creepCorrection = pipelineParameters.getCreepCorrection(); + AlignmentPipelineParameters.validateRequiredElementExists("creepCorrection", + creepCorrection); + creepCorrection.validate(); + } + + /** Run the client as part of an alignment pipeline. */ + @Override + public void runPipelineStep(final JavaSparkContext sparkContext, + final AlignmentPipelineParameters pipelineParameters) + throws IllegalArgumentException, IOException { + + final MultiProjectParameters multiProject = + pipelineParameters.getMultiProject(pipelineParameters.getRawNamingGroup()); + + for (final StackWithZValues stackWithAllZ : multiProject.buildListOfStackWithAllZ()) { + correctCreep(sparkContext, + multiProject.getBaseDataUrl(), + stackWithAllZ, + pipelineParameters.getCreepCorrection(), + multiProject.getMatchCollectionIdForStack(stackWithAllZ.getStackId())); + } + } + + @Override + public AlignmentPipelineStepId getDefaultStepId() { + return AlignmentPipelineStepId.CORRECT_CREEP; + } + + public void correctCreep(final JavaSparkContext sparkContext, + final String baseDataUrl, + final StackWithZValues stackWithAllZ, + final CreepCorrectionParameters creepCorrection, + final MatchCollectionId matchCollectionId) throws IOException { + + final StackId sourceStackId = stackWithAllZ.getStackId(); + final String sourceStackDevString = sourceStackId.toDevString(); + + LOG.info("correctCreep: entry, {} with z {} to {} and creepCorrection {}", + sourceStackDevString, stackWithAllZ.getFirstZ(), stackWithAllZ.getLastZ(), creepCorrection); + + final String sourceStack = sourceStackId.getStack(); + final String targetStack = creepCorrection.getTargetStack(sourceStack); + final MatchCollectionId targetMatchCollectionId = new MatchCollectionId(matchCollectionId.getOwner(), + targetStack + "_match"); + + final RenderDataClient sourceDataClient = new RenderDataClient(baseDataUrl, + sourceStackId.getOwner(), + sourceStackId.getProject()); + + final List correctedZValues = creepCorrection.buildCorrectedZValues(stackWithAllZ.getzValues()); + if (correctedZValues.isEmpty()) { + throw new IllegalArgumentException("source stack does not contain any matching z values to be corrected"); + } + + // set up target stack on the driver + final StackMetaData sourceStackMetaData = sourceDataClient.getStackMetaData(sourceStack); + sourceDataClient.setupDerivedStack(sourceStackMetaData, targetStack); + + // Phase 1: process tiles and collect corrections + LOG.info("correctCreep: {}, phase 1 - distributing {} z values for tile correction", + sourceStackDevString, correctedZValues.size()); + + final JavaRDD rddCorrectedZValues = sparkContext.parallelize(correctedZValues); + + final JavaRDD> rddResults = + rddCorrectedZValues.map(z -> processSingleLayer( + baseDataUrl, + sourceStackId, + creepCorrection.targetStackSuffix, + z, + matchCollectionId)); + final List> resultList = rddResults.collect(); + + // collect all corrections on the driver + final Map> allResults = new HashMap<>(); + for (int i = 0; i < correctedZValues.size(); i++) { + allResults.put(String.valueOf(correctedZValues.get(i).doubleValue()), resultList.get(i)); + } + + final List uncorrectedZValues = creepCorrection.buildUncorrectedZValues(stackWithAllZ.getzValues()); + if (! uncorrectedZValues.isEmpty()) { + + LOG.info("correctCreep: {}, phase 1 - distributing {} z values for simple copy", + sourceStackDevString, uncorrectedZValues.size()); + + final JavaRDD rddUncorrectedZValues = sparkContext.parallelize(uncorrectedZValues); + final JavaRDD rddCopyResults = + rddUncorrectedZValues.map(z -> copySingleLayer( + baseDataUrl, + sourceStackId, + creepCorrection.targetStackSuffix, + z)); + rddCopyResults.collect(); + } + + LOG.info("correctCreep: {}, phase 1 complete - processed {} z-layers", + sourceStackDevString, correctedZValues.size()); + + // write or log results on the driver + reportResults(creepCorrection, sourceStackId, allResults); + + // complete target stack on the driver + sourceDataClient.setStackState(targetStack, StackMetaData.StackState.COMPLETE); + + // Phase 2: transform matches + if (! creepCorrection.skipMatchCorrection) { + transformMatches(sparkContext, baseDataUrl, matchCollectionId, targetMatchCollectionId, allResults); + } else { + LOG.info("correctCreep: skipping match correction"); + } + + LOG.info("correctCreep: exit"); + } + + private void transformMatches(final JavaSparkContext sparkContext, + final String baseDataUrl, + final MatchCollectionId matchCollectionId, + final MatchCollectionId targetMatchCollectionId, + final Map> allResults) + throws IOException { + + final RenderDataClient driverMatchClient = new RenderDataClient(baseDataUrl, + matchCollectionId.getOwner(), + matchCollectionId.getName()); + + final List pGroupIds = driverMatchClient.getMatchPGroupIds(); + + if (pGroupIds.isEmpty()) { + LOG.info("transformMatches: no match groups found, skipping"); + return; + } + + LOG.info("transformMatches: Phase 2 - distributing {} match groups for coordinate transformation", + pGroupIds.size()); + + final Broadcast>> broadcastResults = sparkContext.broadcast(allResults); + + final JavaRDD rddGroupIds = sparkContext.parallelize(pGroupIds); + + rddGroupIds.foreach(groupId -> + transformMatchesForSingleGroup(baseDataUrl, + matchCollectionId, + targetMatchCollectionId, + groupId, + broadcastResults.value())); + + LOG.info("transformMatches: Phase 2 complete - transformed matches for {} groups", + pGroupIds.size()); + } + + private List processSingleLayer(final String baseDataUrl, + final StackId stackId, + final String targetStackSuffix, + final Double z, + final MatchCollectionId matchCollectionId) throws IOException { + + LogUtilities.setupExecutorLog4j(stackId.toDevString() + "::z" + z); + + final RenderDataClient executorRenderClient = new RenderDataClient(baseDataUrl, + stackId.getOwner(), + stackId.getProject()); + final RenderDataClient executorMatchClient = new RenderDataClient(baseDataUrl, + matchCollectionId.getOwner(), + matchCollectionId.getName()); + + final CreepCorrectionClient correctionClient = new CreepCorrectionClient(); + final String targetStack = stackId.getStack() + targetStackSuffix; + return correctionClient.processZLayer(z, + executorRenderClient, + executorMatchClient, + stackId.getStack(), + targetStack); + } + + private Double copySingleLayer(final String baseDataUrl, + final StackId stackId, + final String targetStackSuffix, + final Double z) throws IOException { + + LogUtilities.setupExecutorLog4j(stackId.toDevString() + "::z" + z); + + final RenderDataClient executorRenderClient = new RenderDataClient(baseDataUrl, + stackId.getOwner(), + stackId.getProject()); + final String targetStack = stackId.getStack() + targetStackSuffix; + + final ResolvedTileSpecCollection resolvedTiles = executorRenderClient.getResolvedTiles(stackId.getStack(), z); + executorRenderClient.saveResolvedTiles(resolvedTiles, targetStack, z); + + return z; + } + + private void transformMatchesForSingleGroup(final String baseDataUrl, + final MatchCollectionId matchCollectionId, + final MatchCollectionId targetMatchCollectionId, + final String groupId, + final Map> allResults) + throws IOException { + + LogUtilities.setupExecutorLog4j(matchCollectionId.toDevString() + "::" + groupId); + + final RenderDataClient sourceMatchClient = new RenderDataClient(baseDataUrl, + matchCollectionId.getOwner(), + matchCollectionId.getName()); + final RenderDataClient targetMatchClient = new RenderDataClient(baseDataUrl, + targetMatchCollectionId.getOwner(), + targetMatchCollectionId.getName()); + + final CreepCorrectionClient correctionClient = new CreepCorrectionClient(); + correctionClient.transformMatchesForGroup(groupId, + allResults, + sourceMatchClient, + targetMatchClient); + } + + private void reportResults(final CreepCorrectionParameters creepCorrection, + final StackId sourceStackId, + final Map> allResults) { + + final List sortedScans = new ArrayList<>(allResults.keySet()); + sortedScans.sort(Comparator.comparingDouble(Double::parseDouble)); + + boolean written = false; + if (creepCorrection.parameterCsvDir != null) { + + final File csvFile = new File(creepCorrection.parameterCsvDir, + "creep-correction-param." + sourceStackId.getStack() + ".csv"); + + try (final PrintWriter writer = new PrintWriter(csvFile)) { + writer.println("scan," + MfovResult.CSV_HEADER); + for (final String scan : sortedScans) { + for (final MfovResult result : allResults.get(scan)) { + writer.println(scan + "," + result.toCsvRow()); + } + } + written = true; + LOG.info("reportResults: wrote creep correction parameters to {}", csvFile); + } catch (final Exception e) { + LOG.error("reportResults: failed to write CSV to {}, logging instead", csvFile, e); + } + } + + if (!written) { + LOG.info("reportResults: scan,{}", MfovResult.CSV_HEADER); + for (final String scan : sortedScans) { + for (final MfovResult result : allResults.get(scan)) { + LOG.info("reportResults: {},{}", scan, result.toCsvRow()); + } + } + } + } + + private static final Logger LOG = LoggerFactory.getLogger(CreepCorrectionSparkClient.class); +} diff --git a/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/pipeline/AlignmentPipelineParameters.java b/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/pipeline/AlignmentPipelineParameters.java index 3e0951e88..04a1662d9 100644 --- a/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/pipeline/AlignmentPipelineParameters.java +++ b/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/pipeline/AlignmentPipelineParameters.java @@ -16,6 +16,7 @@ import org.janelia.alignment.util.UrlResourceUtil; import org.janelia.render.client.newsolver.setup.AffineBlockSolverSetup; import org.janelia.render.client.newsolver.setup.IntensityCorrectionSetup; +import org.janelia.render.client.parameter.CreepCorrectionParameters; import org.janelia.render.client.parameter.MFOVAsTileParameters; import org.janelia.render.client.parameter.MFOVMontageMatchPatchParameters; import org.janelia.render.client.parameter.MaskHackParameters; @@ -48,6 +49,7 @@ public class AlignmentPipelineParameters private final UnconnectedCrossMFOVParameters unconnectedCrossMfov; private final TileClusterParameters tileCluster; private final MatchCopyParameters matchCopy; + private final CreepCorrectionParameters creepCorrection; private final AffineBlockSolverSetup affineBlockSolverSetup; private final IntensityCorrectionSetup intensityCorrectionSetup; private final ZSpacingParameters zSpacing; @@ -73,6 +75,7 @@ public AlignmentPipelineParameters() { null, null, null, + null, null); } @@ -85,6 +88,7 @@ public AlignmentPipelineParameters(final MultiProjectParameters multiProject, final UnconnectedCrossMFOVParameters unconnectedCrossMfov, final TileClusterParameters tileCluster, final MatchCopyParameters matchCopy, + final CreepCorrectionParameters creepCorrection, final AffineBlockSolverSetup affineBlockSolverSetup, final IntensityCorrectionSetup intensityCorrectionSetup, final ZSpacingParameters zSpacing, @@ -101,6 +105,7 @@ public AlignmentPipelineParameters(final MultiProjectParameters multiProject, this.unconnectedCrossMfov = unconnectedCrossMfov; this.tileCluster = tileCluster; this.matchCopy = matchCopy; + this.creepCorrection = creepCorrection; this.affineBlockSolverSetup = affineBlockSolverSetup; this.intensityCorrectionSetup = intensityCorrectionSetup; this.zSpacing = zSpacing; @@ -156,6 +161,10 @@ public MatchCopyParameters getMatchCopy() { return matchCopy; } + public CreepCorrectionParameters getCreepCorrection() { + return creepCorrection; + } + public String getMatchCopyToCollectionSuffix() { return matchCopy == null ? "" : matchCopy.toCollectionSuffix; } diff --git a/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/pipeline/AlignmentPipelineStepId.java b/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/pipeline/AlignmentPipelineStepId.java index 0af0b09b8..4f1627590 100644 --- a/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/pipeline/AlignmentPipelineStepId.java +++ b/render-ws-spark-client/src/main/java/org/janelia/render/client/spark/pipeline/AlignmentPipelineStepId.java @@ -8,6 +8,7 @@ import org.janelia.render.client.spark.match.ClusterCountClient; import org.janelia.render.client.spark.match.CopyMatchClient; import org.janelia.render.client.spark.match.MultiStagePointMatchClient; +import org.janelia.render.client.spark.multisem.CreepCorrectionSparkClient; import org.janelia.render.client.spark.multisem.MFOVASTileClient; import org.janelia.render.client.spark.multisem.MFOVMontageMatchPatchClient; import org.janelia.render.client.spark.multisem.UnconnectedCrossMFOVClient; @@ -29,6 +30,7 @@ public enum AlignmentPipelineStepId { FIND_UNCONNECTED_CROSS_MFOVS(UnconnectedCrossMFOVClient::new), FIND_UNCONNECTED_TILES_AND_EDGES(ClusterCountClient::new), FILTER_MATCHES(CopyMatchClient::new), + CORRECT_CREEP(CreepCorrectionSparkClient::new), ALIGN_TILES(DistributedAffineBlockSolverClient::new), CORRECT_Z_POSITIONS(ZPositionCorrectionClient::new), CORRECT_INTENSITY(DistributedIntensityCorrectionBlockSolverClient::new),