From 3f40f417ebe264cfcb4da1580dcc57eeec76e4ac Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Fri, 15 May 2026 17:04:00 +0000 Subject: [PATCH 1/3] [SYSTEMDS-2651] Replace fixed-sleep federated worker startup with TCP port polling Federated tests previously slept for FED_WORKER_WAIT (or FED_WORKER_WAIT_S) after launching each federated worker, racing against the worker's actual bind and wasting time once it was ready. With many tests starting multiple workers serially, this added significant fixed overhead per test. Replace with a TCP connect-poll on the worker's port, which becomes connectable only after Netty's bind().sync() has fully completed. Tests now return as soon as the worker is genuinely ready and fail fast if it never comes up. - Add FederatedTestUtils.waitForWorker(...) overloads that fail fast if the worker process/thread dies during startup. - Use it from startLocalFedWorker, startLocalFedWorkerThread, and startLocalFedWorkerWithArgs in AutomatedTestBase. - Keep the legacy int parameter on the public methods for source compatibility; reinterpret it as an upper-bound timeout with a sane floor so tiny historical values (e.g. 50 ms) don't time out on slow CI. Production code is untouched; the readiness check lives entirely in the test tree. --- .../apache/sysds/test/AutomatedTestBase.java | 127 ++++++++++-------- .../apache/sysds/test/FederatedTestUtils.java | 120 +++++++++++++++++ 2 files changed, 192 insertions(+), 55 deletions(-) create mode 100644 src/test/java/org/apache/sysds/test/FederatedTestUtils.java diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java index c7f62b02a2b..9d043bd9da0 100644 --- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java +++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java @@ -1665,14 +1665,21 @@ protected Process startLocalFedWorker(int port, String[] addArgs) { /** * Start new JVM for a federated worker at the port. - * - * @param port Port to use for the JVM - * @param addArgs The arguments to add - * @param sleep The time to wait for the process to start + * + *

Returns once the worker's TCP port is accepting connections (the + * worker only opens the port after Netty's bind has fully completed) or + * after {@code timeoutMs} elapses, in which case a {@link RuntimeException} + * is thrown. Replaces the previous fixed sleep so callers don't race against + * a worker that hasn't yet bound, and don't waste time after one that has. + * + * @param port Port to use for the JVM + * @param addArgs The arguments to add + * @param timeoutMs Upper bound on the wait for the worker to become ready + * (kept for source compatibility with the previous "sleep" + * parameter; clamped to a sane floor inside). * @return the process associated with the worker. */ - protected static Process startLocalFedWorker(int port, String[] addArgs, int sleep) { - Process process = null; + protected static Process startLocalFedWorker(int port, String[] addArgs, int timeoutMs) { String separator = System.getProperty("file.separator"); String classpath = System.getProperty("java.class.path"); String path = System.getProperty("java.home") + separator + "bin" + separator + "java"; @@ -1701,18 +1708,17 @@ protected static Process startLocalFedWorker(int port, String[] addArgs, int sle DMLScript.class.getName(), "-w", Integer.toString(port), "-stats"}); if(addArgs != null) args = ArrayUtils.addAll(args, addArgs); - + ProcessBuilder processBuilder = new ProcessBuilder(args).inheritIO(); + Process process; try { process = processBuilder.start(); - // Give some time to startup the worker. - sleep(sleep); } - catch(IOException | InterruptedException e) { - e.printStackTrace(); + catch(IOException e) { + throw new RuntimeException("Failed to launch federated worker process on port " + port, e); } - isAlive(process); + FederatedTestUtils.waitForWorker(process, port, timeoutMs); return process; } @@ -1782,50 +1788,46 @@ public static Thread startLocalFedWorkerThread(int port, int sleep) { } /** - * Start a thread for a worker. This will share the same JVM, so all static variables will be shared.! - * - * Also when using the local Fed Worker thread the statistics printing, and clearing from the worker is disabled. - * + * Start a thread for a worker. This will share the same JVM, so all static variables will be shared. + * + *

Also when using the local Fed Worker thread the statistics printing, and clearing from the worker is + * disabled. + * + *

Returns once the worker's TCP port is accepting connections (the worker only opens the port after + * Netty's bind has fully completed) or after {@code timeoutMs} elapses, in which case a + * {@link RuntimeException} is thrown. + * * @param port Port to use * @param otherArgs The command line arguments to start the worker with - * @param sleep The amount of time to wait for the worker startup. in Milliseconds + * @param timeoutMs Upper bound on the wait for the worker to become ready (kept for source compatibility + * with the previous "sleep" parameter; clamped to a sane floor inside). * @return The thread associated with the worker. */ - public static Thread startLocalFedWorkerThread(int port, String[] otherArgs, int sleep) { + public static Thread startLocalFedWorkerThread(int port, String[] otherArgs, int timeoutMs) { ArrayList args = new ArrayList<>(); - + args.add("-w"); args.add(Integer.toString(port)); if(otherArgs != null) - for( String s : otherArgs) + for(String s : otherArgs) args.add(s); String[] finalArguments = args.toArray(new String[args.size()]); Statistics.allowWorkerStatistics = false; - try { - Thread t = new Thread(() -> { - try { - main(finalArguments); - } - catch(Exception e) { - LOG.error("Exception in startup of federated worker", e); - } - }); - t.start(); - java.util.concurrent.TimeUnit.MILLISECONDS.sleep(sleep); - if(!t.isAlive()) - throw new RuntimeException("Failed starting federated worker"); - return t; - } - catch(InterruptedException e) { - e.printStackTrace(); - fail("Failed to start federated worker : " + e); - // should never happen - return null; - } + Thread t = new Thread(() -> { + try { + main(finalArguments); + } + catch(Exception e) { + LOG.error("Exception in startup of federated worker", e); + } + }); + t.start(); + FederatedTestUtils.waitForWorker(t, port, timeoutMs); + return t; } public static boolean isAlive(Thread... threads){ @@ -1846,28 +1848,43 @@ public static boolean isAlive(Process... processes) { /** * Start java worker in same JVM. - * + * + *

Returns once the worker's TCP port is accepting connections (the worker only opens the port after + * Netty's bind has fully completed) or after the default federated worker timeout elapses, in which case + * a {@link RuntimeException} is thrown. The port is extracted from {@code args} (expects {@code "-w" }). + * * @param args the command line arguments - * @return the thread associated with the process.s + * @return the thread associated with the worker. */ public static Thread startLocalFedWorkerWithArgs(String[] args) { - Thread t = null; + final int port = extractWorkerPort(args); + Thread t = new Thread(() -> { + try { + main(args); + } + catch(IOException e) { + // suppressed: matches previous behaviour of this helper + } + }); + t.start(); + FederatedTestUtils.waitForWorker(t, port, FED_WORKER_WAIT); + return t; + } - try { - t = new Thread(() -> { + private static int extractWorkerPort(String[] args) { + for(int i = 0; i < args.length - 1; i++) { + if("-w".equals(args[i])) { try { - main(args); + return Integer.parseInt(args[i + 1]); } - catch(IOException e) { + catch(NumberFormatException e) { + throw new IllegalArgumentException( + "Federated worker args contain non-numeric port after -w: " + args[i + 1], e); } - }); - t.start(); - java.util.concurrent.TimeUnit.MILLISECONDS.sleep(FED_WORKER_WAIT); - } - catch(InterruptedException e) { - // Should happen at closing of the worker so don't print + } } - return t; + throw new IllegalArgumentException("Federated worker args must contain '-w ': " + + Arrays.toString(args)); } private boolean rCompareException(boolean exceptionExpected, String errMessage, Throwable e, boolean result) { diff --git a/src/test/java/org/apache/sysds/test/FederatedTestUtils.java b/src/test/java/org/apache/sysds/test/FederatedTestUtils.java new file mode 100644 index 00000000000..184440cfb6e --- /dev/null +++ b/src/test/java/org/apache/sysds/test/FederatedTestUtils.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.Socket; +import java.util.function.BooleanSupplier; + +/** + * Test-only helpers for waiting on federated worker startup. + * + *

The federated worker only opens its TCP port once Netty's + * {@code bind().sync()} has fully completed, so a successful TCP connect to + * that port is a precise readiness signal. Tests use this instead of fixed + * sleeps so they finish as soon as the worker is up and fail fast if the + * worker never comes up (or dies during startup). + * + *

This lives in the test tree on purpose; the production + * {@code FederatedWorker} carries no test-only callback hooks. + */ +public final class FederatedTestUtils { + + /** How often to retry the TCP connect while waiting (ms). */ + private static final int POLL_INTERVAL_MS = 25; + + /** + * Lower clamp on the caller-supplied timeout. Several historical call sites + * pass values like {@code FED_WORKER_WAIT_S = 50}, which were intended as a + * "short fixed sleep" before the next worker was started in parallel. With + * polling we exit as soon as the worker is ready, so only the upper bound + * matters; clamping prevents those tiny legacy values from turning into + * false-positive timeouts on slow CI machines. + */ + private static final int MIN_TIMEOUT_MS = 30_000; + + private FederatedTestUtils() { + // utility class + } + + /** + * Block until a federated worker is accepting TCP connections on + * {@code port}, or throw a {@link RuntimeException} after the timeout. + * + * @param port port the federated worker is expected to bind + * @param timeoutMs upper bound on the wait, in ms (clamped to a sane floor) + */ + public static void waitForWorker(int port, int timeoutMs) { + waitForWorker(port, timeoutMs, () -> true, "worker"); + } + + /** + * Block until a federated worker is accepting TCP connections on + * {@code port}. Fails fast if the supplied {@code aliveCheck} reports the + * worker is no longer alive. + */ + public static void waitForWorker(int port, int timeoutMs, BooleanSupplier aliveCheck, String workerKind) { + final int effectiveTimeout = Math.max(timeoutMs, MIN_TIMEOUT_MS); + final long deadline = System.currentTimeMillis() + effectiveTimeout; + while(System.currentTimeMillis() < deadline) { + if(!aliveCheck.getAsBoolean()) { + throw new RuntimeException( + "Federated " + workerKind + " on port " + port + " died before becoming ready."); + } + if(tryConnect(port)) { + return; + } + sleepQuietly(); + } + throw new RuntimeException("Federated " + workerKind + " on port " + port + + " did not become ready within " + effectiveTimeout + "ms."); + } + + /** Convenience: also fail fast if the worker process dies during startup. */ + public static void waitForWorker(Process process, int port, int timeoutMs) { + waitForWorker(port, timeoutMs, process::isAlive, "worker process"); + } + + /** Convenience: also fail fast if the worker thread dies during startup. */ + public static void waitForWorker(Thread thread, int port, int timeoutMs) { + waitForWorker(port, timeoutMs, thread::isAlive, "worker thread"); + } + + private static boolean tryConnect(int port) { + try(Socket s = new Socket()) { + s.connect(new InetSocketAddress("localhost", port), POLL_INTERVAL_MS); + return true; + } + catch(IOException e) { + return false; + } + } + + private static void sleepQuietly() { + try { + Thread.sleep(POLL_INTERVAL_MS); + } + catch(InterruptedException ie) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted while waiting for federated worker", ie); + } + } +} From dffe6fff5d5c11b36e565543f648a1d7355e226f Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Fri, 15 May 2026 17:57:07 +0000 Subject: [PATCH 2/3] [SYSTEMDS-2651] Bulk-start federated workers in tests instead of per-worker waits Many federated tests start 2-4 workers and then wait for each one before spawning the next. Even with TCP port-polling readiness, that serialises the JVM warm-up and dominates wall-clock time per test. Switch to a bulk pattern: spawn all workers back-to-back, then wait for them all in one shared poll loop, so total startup is bounded by the slowest worker rather than the sum of all of them. - AutomatedTestBase: factor out spawnLocalFedWorker / spawnLocalFedWorkerThread primitives, add startLocalFedWorkers / startLocalFedWorkerThreads bulk helpers returning Process[] / Thread[]. - FederatedTestUtils: add waitForWorkers(int[], ...) variants that poll all ports in a single loop and fail fast if any worker dies during startup. - Add a small (25 ms) stagger between in-JVM thread spawns to avoid races on shared static state during DMLScript / FederatedWorker init that the previous per-thread sleep was implicitly providing. - Migrate ~70 federated test files (primitives, algorithms, codegen, fedplanning, paramserv-base, lineage, transform, io) and MultiTenantTestBase.startFedWorkers to the bulk helper, including their isAlive(...) and TestUtils.shutdownThreads(...) varargs sites. Files with conditional / interleaved / single-worker startup are left alone. No production code is touched. --- .../apache/sysds/test/AutomatedTestBase.java | 107 ++++++++++++++++-- .../apache/sysds/test/FederatedTestUtils.java | 75 ++++++++++++ .../algorithms/FederatedAlsCGTest.java | 5 +- .../algorithms/FederatedBivarTest.java | 7 +- .../algorithms/FederatedCorTest.java | 7 +- .../algorithms/FederatedGLMTest.java | 5 +- .../algorithms/FederatedKmeansTest.java | 5 +- .../algorithms/FederatedL2SVMTest.java | 5 +- .../algorithms/FederatedLmPipeline.java | 7 +- .../algorithms/FederatedLogRegTest.java | 5 +- .../algorithms/FederatedMSVMTest.java | 5 +- .../algorithms/FederatedPCATest.java | 7 +- .../algorithms/FederatedPNMFTest.java | 5 +- .../algorithms/FederatedUnivarTest.java | 7 +- .../algorithms/FederatedVarTest.java | 7 +- .../algorithms/FederatedYL2SVMTest.java | 5 +- .../codegen/FederatedCellwiseTmplTest.java | 5 +- .../FederatedCodegenMultipleFedMOTest.java | 5 +- .../codegen/FederatedMultiAggTmplTest.java | 5 +- .../FederatedOuterProductTmplTest.java | 5 +- .../codegen/FederatedRowwiseTmplTest.java | 5 +- .../FederatedDynamicPlanningTest.java | 7 +- .../FederatedKMeansPlanningTest.java | 7 +- .../FederatedL2SVMPlanningTest.java | 7 +- .../FederatedMultiplyPlanningTest.java | 7 +- .../federated/io/FederatedReaderTest.java | 5 +- .../federated/io/FederatedSSLTest.java | 5 +- .../io/FederatedSparsityPropagationTest.java | 5 +- .../federated/io/FederatedWriterTest.java | 5 +- .../multitenant/MultiTenantTestBase.java | 13 +-- .../part1/FederatedBinaryMatrixTest.java | 7 +- .../part1/FederatedBinaryVectorTest.java | 7 +- .../part1/FederatedBroadcastTest.java | 7 +- .../part1/FederatedCastToFrameTest.java | 8 +- .../part1/FederatedCastToMatrixTest.java | 7 +- .../part1/FederatedCentralMomentTest.java | 10 +- .../part1/FederatedColAggregateTest.java | 9 +- .../part1/FederatedLeftIndexTest.java | 9 +- .../part1/FederatedMisAlignedTest.java | 10 +- .../part2/FederatedMultiplyTest.java | 7 +- .../primitives/part2/FederatedProdTest.java | 9 +- .../part2/FederatedQuantileTest.java | 11 +- .../part2/FederatedQuantileWeightsTest.java | 11 +- .../primitives/part2/FederatedRCBindTest.java | 9 +- .../primitives/part2/FederatedRdiagTest.java | 9 +- .../part2/FederatedRemoveEmptyTest.java | 9 +- .../part2/FederatedReplaceTest.java | 9 +- .../part2/FederatedReshapeTest.java | 9 +- .../primitives/part2/FederatedRevTest.java | 9 +- .../part2/FederatedRightIndexTest.java | 9 +- .../primitives/part2/FederatedRollTest.java | 9 +- .../part2/FederatedRowIndexTest.java | 9 +- .../primitives/part3/FederatedSplitTest.java | 7 +- .../part3/FederatedStatisticsTest.java | 7 +- .../part3/FederatedTokenizeTest.java | 8 +- .../part3/FederatedTransferLocalDataTest.java | 9 +- .../primitives/part3/FederatedTriTest.java | 9 +- .../FederatedWeightedCrossEntropyTest.java | 7 +- .../FederatedWeightedDivMatrixMultTest.java | 7 +- .../part3/FederatedWeightedSigmoidTest.java | 7 +- .../FederatedWeightedSquaredLossTest.java | 7 +- .../FederatedWeightedUnaryMatrixMultTest.java | 7 +- .../part4/FederatedRowAggregateTest.java | 9 +- .../part5/FederatedCovarianceTest.java | 18 +-- .../primitives/part5/FederatedCtableTest.java | 9 +- .../part5/FederatedFrameMapTest.java | 9 +- .../part5/FederatedFullAggregateTest.java | 9 +- .../part5/FederatedFullCumulativeTest.java | 9 +- .../primitives/part5/FederatedIfelseTest.java | 9 +- .../part5/FederatedMMChainTest.java | 9 +- .../TransformFederatedEncodeApplyTest.java | 9 +- .../TransformFederatedEncodeDecodeTest.java | 9 +- .../functions/lineage/FedFullReuseTest.java | 5 +- .../functions/lineage/FedUDFReuseTest.java | 7 +- .../functions/lineage/LineageFedReuseAlg.java | 5 +- 75 files changed, 378 insertions(+), 356 deletions(-) diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java index 9d043bd9da0..ea8df42a34d 100644 --- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java +++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java @@ -1680,10 +1680,52 @@ protected Process startLocalFedWorker(int port, String[] addArgs) { * @return the process associated with the worker. */ protected static Process startLocalFedWorker(int port, String[] addArgs, int timeoutMs) { + Process process = spawnLocalFedWorker(port, addArgs); + FederatedTestUtils.waitForWorker(process, port, timeoutMs); + return process; + } + + /** + * Start N federated worker JVMs in parallel, then wait for them all to be ready in one shared poll loop. + * + *

Wall-clock cost is bounded by the slowest worker's startup -- not the sum -- so this is the right + * helper for tests that need multiple federated workers. + * + * @param ports Ports to use, one per worker + * @return The process per port, in the same order as {@code ports}. + */ + protected static Process[] startLocalFedWorkers(int[] ports) { + return startLocalFedWorkers(ports, null, FED_WORKER_WAIT); + } + + /** @see #startLocalFedWorkers(int[], String[], int) */ + protected static Process[] startLocalFedWorkers(int[] ports, String[] addArgs) { + return startLocalFedWorkers(ports, addArgs, FED_WORKER_WAIT); + } + + /** + * Start N federated worker JVMs in parallel, then wait for them all to be ready. + * + * @param ports Ports to use, one per worker + * @param addArgs Extra worker CLI args (applied to every worker), or null + * @param timeoutMs Upper bound on the wait, in ms (clamped to a sane floor) + * @return The process per port, in the same order as {@code ports}. + */ + protected static Process[] startLocalFedWorkers(int[] ports, String[] addArgs, int timeoutMs) { + Process[] processes = new Process[ports.length]; + for(int i = 0; i < ports.length; i++) { + processes[i] = spawnLocalFedWorker(ports[i], addArgs); + } + FederatedTestUtils.waitForWorkers(processes, ports, timeoutMs); + return processes; + } + + /** Spawn a federated worker JVM and return immediately, without waiting for the port to bind. */ + private static Process spawnLocalFedWorker(int port, String[] addArgs) { String separator = System.getProperty("file.separator"); String classpath = System.getProperty("java.class.path"); String path = System.getProperty("java.home") + separator + "bin" + separator + "java"; - String[] args = new String[] {path, "-Xmx1000m", "-Xms1000m", "-Xmn100m", + String[] args = new String[] {path, "-Xmx1000m", "-Xms1000m", "-Xmn100m", "--add-opens=java.base/java.nio=ALL-UNNAMED" , "--add-opens=java.base/java.io=ALL-UNNAMED" , "--add-opens=java.base/java.util=ALL-UNNAMED" , @@ -1710,16 +1752,12 @@ protected static Process startLocalFedWorker(int port, String[] addArgs, int tim args = ArrayUtils.addAll(args, addArgs); ProcessBuilder processBuilder = new ProcessBuilder(args).inheritIO(); - - Process process; try { - process = processBuilder.start(); + return processBuilder.start(); } catch(IOException e) { throw new RuntimeException("Failed to launch federated worker process on port " + port, e); } - FederatedTestUtils.waitForWorker(process, port, timeoutMs); - return process; } /** @@ -1804,12 +1842,64 @@ public static Thread startLocalFedWorkerThread(int port, int sleep) { * @return The thread associated with the worker. */ public static Thread startLocalFedWorkerThread(int port, String[] otherArgs, int timeoutMs) { + Thread t = spawnLocalFedWorkerThread(port, otherArgs); + FederatedTestUtils.waitForWorker(t, port, timeoutMs); + return t; + } - ArrayList args = new ArrayList<>(); + /** + * Start N federated worker threads in the same JVM in parallel, then wait for them all to be ready in + * one shared poll loop. + * + * @param ports Ports to use, one per worker + * @return The thread per port, in the same order as {@code ports}. + */ + public static Thread[] startLocalFedWorkerThreads(int[] ports) { + return startLocalFedWorkerThreads(ports, null, FED_WORKER_WAIT); + } + + /** @see #startLocalFedWorkerThreads(int[], String[], int) */ + public static Thread[] startLocalFedWorkerThreads(int[] ports, String[] otherArgs) { + return startLocalFedWorkerThreads(ports, otherArgs, FED_WORKER_WAIT); + } + + /** + * Start N federated worker threads in the same JVM in parallel, then wait for them all to be ready. + * + * @param ports Ports to use, one per worker + * @param otherArgs Extra worker CLI args (applied to every worker), or null + * @param timeoutMs Upper bound on the wait, in ms (clamped to a sane floor) + * @return The thread per port, in the same order as {@code ports}. + */ + public static Thread[] startLocalFedWorkerThreads(int[] ports, String[] otherArgs, int timeoutMs) { + Thread[] threads = new Thread[ports.length]; + for(int i = 0; i < ports.length; i++) { + threads[i] = spawnLocalFedWorkerThread(ports[i], otherArgs); + // Tiny stagger between in-JVM thread spawns: lets each worker pull through the shared + // DMLScript / FederatedWorker static-init (LineageCacheConfig setters etc.) without all + // threads racing through the same global state simultaneously. Cheap relative to total + // startup; restores the implicit ordering the old per-thread sleep provided. + if(i + 1 < ports.length) { + try { + java.util.concurrent.TimeUnit.MILLISECONDS.sleep(THREAD_SPAWN_STAGGER_MS); + } + catch(InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted while spawning federated worker threads", e); + } + } + } + FederatedTestUtils.waitForWorkers(threads, ports, timeoutMs); + return threads; + } + + private static final int THREAD_SPAWN_STAGGER_MS = 25; + /** Spawn a federated worker thread in this JVM and return immediately, without waiting for the port to bind. */ + private static Thread spawnLocalFedWorkerThread(int port, String[] otherArgs) { + ArrayList args = new ArrayList<>(); args.add("-w"); args.add(Integer.toString(port)); - if(otherArgs != null) for(String s : otherArgs) args.add(s); @@ -1826,7 +1916,6 @@ public static Thread startLocalFedWorkerThread(int port, String[] otherArgs, int } }); t.start(); - FederatedTestUtils.waitForWorker(t, port, timeoutMs); return t; } diff --git a/src/test/java/org/apache/sysds/test/FederatedTestUtils.java b/src/test/java/org/apache/sysds/test/FederatedTestUtils.java index 184440cfb6e..d328fdee69e 100644 --- a/src/test/java/org/apache/sysds/test/FederatedTestUtils.java +++ b/src/test/java/org/apache/sysds/test/FederatedTestUtils.java @@ -98,6 +98,81 @@ public static void waitForWorker(Thread thread, int port, int timeoutMs) { waitForWorker(port, timeoutMs, thread::isAlive, "worker thread"); } + /** + * Block until all the listed federated workers are accepting TCP + * connections. Polls all ports in a single shared loop, so when N workers + * are spawned in parallel and then this is called, total wall-clock wait is + * bounded by the slowest worker -- not the sum of individual startup times. + * + * @param ports ports the workers are expected to bind + * @param timeoutMs upper bound on the wait, in ms (clamped to a sane floor) + */ + public static void waitForWorkers(int[] ports, int timeoutMs) { + waitForWorkers(ports, timeoutMs, i -> true, "workers"); + } + + /** Convenience: also fail fast if any of the worker processes dies during startup. */ + public static void waitForWorkers(Process[] processes, int[] ports, int timeoutMs) { + if(processes.length != ports.length) { + throw new IllegalArgumentException( + "processes/ports length mismatch: " + processes.length + " vs " + ports.length); + } + waitForWorkers(ports, timeoutMs, i -> processes[i].isAlive(), "worker processes"); + } + + /** Convenience: also fail fast if any of the worker threads dies during startup. */ + public static void waitForWorkers(Thread[] threads, int[] ports, int timeoutMs) { + if(threads.length != ports.length) { + throw new IllegalArgumentException( + "threads/ports length mismatch: " + threads.length + " vs " + ports.length); + } + waitForWorkers(ports, timeoutMs, i -> threads[i].isAlive(), "worker threads"); + } + + /** + * Bulk variant: takes a per-index alive check so the caller can plug in + * {@code Process} or {@code Thread} liveness. Marks each port ready as soon + * as it accepts a connection, and only sleeps between rounds (not between + * port checks), so all ports race in parallel. + */ + public static void waitForWorkers(int[] ports, int timeoutMs, java.util.function.IntPredicate aliveCheck, + String workerKind) { + final int effectiveTimeout = Math.max(timeoutMs, MIN_TIMEOUT_MS); + final long deadline = System.currentTimeMillis() + effectiveTimeout; + final boolean[] ready = new boolean[ports.length]; + int remaining = ports.length; + while(remaining > 0 && System.currentTimeMillis() < deadline) { + boolean madeProgress = false; + for(int i = 0; i < ports.length; i++) { + if(ready[i]) { + continue; + } + if(!aliveCheck.test(i)) { + throw new RuntimeException("Federated " + workerKind + " on port " + ports[i] + + " died before becoming ready."); + } + if(tryConnect(ports[i])) { + ready[i] = true; + remaining--; + madeProgress = true; + } + } + if(remaining > 0 && !madeProgress) { + sleepQuietly(); + } + } + if(remaining > 0) { + StringBuilder sb = new StringBuilder("Federated ").append(workerKind) + .append(" did not all become ready within ").append(effectiveTimeout).append("ms. Pending ports:"); + for(int i = 0; i < ports.length; i++) { + if(!ready[i]) { + sb.append(' ').append(ports[i]); + } + } + throw new RuntimeException(sb.toString()); + } + } + private static boolean tryConnect(int port) { try(Socket s = new Socket()) { s.connect(new InetSocketAddress("localhost", port), POLL_INTERVAL_MS); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedAlsCGTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedAlsCGTest.java index 5e880166afe..3b53c71c0ea 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedAlsCGTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedAlsCGTest.java @@ -120,8 +120,7 @@ public void federatedAlsCG(String testname, ExecMode execMode) fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread thread1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread thread2 = startLocalFedWorkerThread(port2); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}); getAndLoadTestConfiguration(testname); @@ -153,7 +152,7 @@ public void federatedAlsCG(String testname, ExecMode execMode) HashMap fedResults = readDMLMatrixFromOutputDir(OUTPUT_NAME); TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, "Fed", "Ref"); - TestUtils.shutdownThreads(thread1, thread2); + TestUtils.shutdownThreads(workers); // check for federated operations Assert.assertTrue(heavyHittersContainsString("fed_!=")); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedBivarTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedBivarTest.java index f2f35700119..0537ecf18de 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedBivarTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedBivarTest.java @@ -114,10 +114,7 @@ public void federatedL2SVM(Types.ExecMode execMode) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S); - Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S); - Thread t4 = startLocalFedWorkerThread(port4); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2, port3, port4}); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); loadTestConfiguration(config); @@ -139,7 +136,7 @@ public void federatedL2SVM(Types.ExecMode execMode) { // compare via files compareResults(1e-9); - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); // check that federated input files are still existing Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1"))); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedCorTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedCorTest.java index 73bf8e91de6..2c81da6333e 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedCorTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedCorTest.java @@ -102,10 +102,7 @@ private void runAggregateOperationTest(ExecMode execMode) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S); - Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S); - Thread t4 = startLocalFedWorkerThread(port4); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2, port3, port4}); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); loadTestConfiguration(config); @@ -140,7 +137,7 @@ private void runAggregateOperationTest(ExecMode execMode) { Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3"))); Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4"))); - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java index 6d8e8165305..b54cb3d0a64 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java @@ -100,8 +100,7 @@ public void federatedGLM(Types.ExecMode execMode) { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread t2 = startLocalFedWorkerThread(port2); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); loadTestConfiguration(config); @@ -122,7 +121,7 @@ public void federatedGLM(Types.ExecMode execMode) { // compare via files compareResults(1e-2); - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); // check for federated operations Assert.assertTrue(heavyHittersContainsString("fed_ba+*")); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java index c8605ac3d88..c1fc83ae3c3 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java @@ -113,8 +113,7 @@ public void federatedKmeans(Types.ExecMode execMode, boolean singleWorker) { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread t2 = startLocalFedWorkerThread(port2); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); loadTestConfiguration(config); @@ -155,7 +154,7 @@ public void federatedKmeans(Types.ExecMode execMode, boolean singleWorker) { // compare via files // compareResults(1e-9); --> randomized - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); resetExecMode(platformOld); } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedL2SVMTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedL2SVMTest.java index f7040e6a4dc..0b38317f73b 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedL2SVMTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedL2SVMTest.java @@ -104,8 +104,7 @@ public void federatedL2SVM(Types.ExecMode execMode, boolean singleWorker) { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread t2 = startLocalFedWorkerThread(port2); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); loadTestConfiguration(config); @@ -126,7 +125,7 @@ public void federatedL2SVM(Types.ExecMode execMode, boolean singleWorker) { // compare via files compareResults(1e-9); - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLmPipeline.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLmPipeline.java index 8b02246e0aa..f3e921dd935 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLmPipeline.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLmPipeline.java @@ -109,10 +109,7 @@ public void federatedLmPipeline(ExecMode execMode, boolean contSplits, String TE int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S); - Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S); - Thread t4 = startLocalFedWorkerThread(port4); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2, port3, port4}); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); loadTestConfiguration(config); @@ -134,7 +131,7 @@ public void federatedLmPipeline(ExecMode execMode, boolean contSplits, String TE // compare via files compareResults(1e-2); - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); // check correct federated operations diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java index a3e91ef37d7..74882637d95 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java @@ -95,8 +95,7 @@ public void federatedLogReg(Types.ExecMode execMode) { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread t2 = startLocalFedWorkerThread(port2); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); loadTestConfiguration(config); @@ -117,7 +116,7 @@ public void federatedLogReg(Types.ExecMode execMode) { // compare via files compareResults(1e-9); - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); // check for federated operations Assert.assertTrue("contains fed_ba+*", heavyHittersContainsString("fed_ba+*")); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedMSVMTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedMSVMTest.java index c5344f9d847..97a4d84a0b9 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedMSVMTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedMSVMTest.java @@ -97,8 +97,7 @@ public void federatedMSVM(Types.ExecMode execMode, boolean singleWorker) { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread t2 = startLocalFedWorkerThread(port2); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); loadTestConfiguration(config); @@ -119,7 +118,7 @@ public void federatedMSVM(Types.ExecMode execMode, boolean singleWorker) { // compare via files compareResults(1e-9); - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPCATest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPCATest.java index 2a4186680c5..9ddd59483ce 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPCATest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPCATest.java @@ -102,10 +102,7 @@ public void federatedPCA(Types.ExecMode execMode) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S); - Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S); - Thread t4 = startLocalFedWorkerThread(port4); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2, port3, port4}); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); loadTestConfiguration(config); @@ -127,7 +124,7 @@ public void federatedPCA(Types.ExecMode execMode) { // compare via files compareResults(1e-9); - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); // check for federated operations Assert.assertTrue(heavyHittersContainsString("fed_ba+*")); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPNMFTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPNMFTest.java index 19fb72ce47b..183f9900a6d 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPNMFTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPNMFTest.java @@ -113,8 +113,7 @@ public void federatedPNMF(ExecMode execMode) fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread thread1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread thread2 = startLocalFedWorkerThread(port2); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}); getAndLoadTestConfiguration(TEST_NAME); @@ -141,7 +140,7 @@ public void federatedPNMF(ExecMode execMode) HashMap fedResults = readDMLMatrixFromOutputDir(OUTPUT_NAME); TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, "Fed", "Ref"); - TestUtils.shutdownThreads(thread1, thread2); + TestUtils.shutdownThreads(workers); // check for federated operations Assert.assertTrue(heavyHittersContainsString("fed_wcemm")); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedUnivarTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedUnivarTest.java index a4a8236e3a8..21f2e392250 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedUnivarTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedUnivarTest.java @@ -100,10 +100,7 @@ public void federatedL2SVM(Types.ExecMode execMode) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S); - Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S); - Thread t4 = startLocalFedWorkerThread(port4); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2, port3, port4}); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); loadTestConfiguration(config); @@ -126,7 +123,7 @@ public void federatedL2SVM(Types.ExecMode execMode) { // compare via files compareResults(1e-9); - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); // check for federated operations Assert.assertTrue(heavyHittersContainsString("fed_uacmax")); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedVarTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedVarTest.java index db071220d00..445abfa1485 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedVarTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedVarTest.java @@ -115,10 +115,7 @@ private void runAggregateOperationTest(ExecMode execMode, boolean federatedCompi int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S); - Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S); - Thread t4 = startLocalFedWorkerThread(port4); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2, port3, port4}); rtplatform = execMode; if(rtplatform == ExecMode.SPARK) @@ -157,7 +154,7 @@ private void runAggregateOperationTest(ExecMode execMode, boolean federatedCompi Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3"))); Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4"))); - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); OptimizerUtils.FEDERATED_COMPILATION = false; rtplatform = platformOld; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedYL2SVMTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedYL2SVMTest.java index b8eef26f35a..a87328a8aa4 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedYL2SVMTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedYL2SVMTest.java @@ -117,8 +117,7 @@ public void federatedL2SVM(Types.ExecMode execMode, String testName) { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread t2 = startLocalFedWorkerThread(port2); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}); TestConfiguration config = availableTestConfigurations.get(testName); loadTestConfiguration(config); @@ -140,7 +139,7 @@ public void federatedL2SVM(Types.ExecMode execMode, String testName) { // compare via files compareResults(1e-9); - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCellwiseTmplTest.java b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCellwiseTmplTest.java index 2f6e5e46173..94fab208e09 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCellwiseTmplTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCellwiseTmplTest.java @@ -158,8 +158,7 @@ private void testFederatedCodegen(ExecMode exec_mode) { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread thread1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread thread2 = startLocalFedWorkerThread(port2); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}); getAndLoadTestConfiguration(TEST_NAME); @@ -188,7 +187,7 @@ private void testFederatedCodegen(ExecMode exec_mode) { HashMap fedResults = readDMLMatrixFromOutputDir(OUTPUT_NAME); TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, "Fed", "Ref"); - TestUtils.shutdownThreads(thread1, thread2); + TestUtils.shutdownThreads(workers); // check for federated operations Assert.assertTrue(heavyHittersContainsSubString("fed_spoofCell")); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCodegenMultipleFedMOTest.java b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCodegenMultipleFedMOTest.java index 61722dbc467..a23219e1587 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCodegenMultipleFedMOTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCodegenMultipleFedMOTest.java @@ -194,8 +194,7 @@ private void testFederatedCodegenMultipleFedMO(ExecMode exec_mode) { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread thread1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread thread2 = startLocalFedWorkerThread(port2); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}); getAndLoadTestConfiguration(TEST_NAME); @@ -228,7 +227,7 @@ private void testFederatedCodegenMultipleFedMO(ExecMode exec_mode) { HashMap fedResults = readDMLMatrixFromOutputDir(OUTPUT_NAME); TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, "Fed", "Ref"); - TestUtils.shutdownThreads(thread1, thread2); + TestUtils.shutdownThreads(workers); // check for federated operations if(test_num >= 0 && test_num < 100) diff --git a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedMultiAggTmplTest.java b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedMultiAggTmplTest.java index 33a551d28d0..74cd2c0a420 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedMultiAggTmplTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedMultiAggTmplTest.java @@ -143,8 +143,7 @@ private void testFederatedCodegenMultiAgg(ExecMode exec_mode) { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread thread1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread thread2 = startLocalFedWorkerThread(port2); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}); getAndLoadTestConfiguration(TEST_NAME); @@ -173,7 +172,7 @@ private void testFederatedCodegenMultiAgg(ExecMode exec_mode) { HashMap fedResults = readDMLMatrixFromOutputDir(OUTPUT_NAME); TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, "Fed", "Ref"); - TestUtils.shutdownThreads(thread1, thread2); + TestUtils.shutdownThreads(workers); // check for federated operations Assert.assertTrue(heavyHittersContainsSubString("fed_spoofMA")); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedOuterProductTmplTest.java b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedOuterProductTmplTest.java index cef5fd5e99f..dd08ae58636 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedOuterProductTmplTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedOuterProductTmplTest.java @@ -146,8 +146,7 @@ private void testFederatedCodegenOuterProduct(ExecMode exec_mode) { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread thread1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread thread2 = startLocalFedWorkerThread(port2); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}); getAndLoadTestConfiguration(TEST_NAME); @@ -176,7 +175,7 @@ private void testFederatedCodegenOuterProduct(ExecMode exec_mode) { HashMap fedResults = readDMLMatrixFromOutputDir(OUTPUT_NAME); TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, "Fed", "Ref"); - TestUtils.shutdownThreads(thread1, thread2); + TestUtils.shutdownThreads(workers); // check for federated operations Assert.assertTrue(heavyHittersContainsSubString("fed_spoofOP")); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedRowwiseTmplTest.java b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedRowwiseTmplTest.java index b47a718c0ca..185c15557d1 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedRowwiseTmplTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedRowwiseTmplTest.java @@ -147,8 +147,7 @@ private void testFederatedCodegenRowwise(ExecMode exec_mode) { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread thread1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread thread2 = startLocalFedWorkerThread(port2); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}); getAndLoadTestConfiguration(TEST_NAME); @@ -178,7 +177,7 @@ private void testFederatedCodegenRowwise(ExecMode exec_mode) { TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, "Fed", "Ref"); - TestUtils.shutdownThreads(thread1, thread2); + TestUtils.shutdownThreads(workers); // check for federated operations if(row_partitioned) diff --git a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedDynamicPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedDynamicPlanningTest.java index 53dce3f01c7..bcc2b61adf2 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedDynamicPlanningTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedDynamicPlanningTest.java @@ -118,7 +118,7 @@ private void loadAndRunTest(String[] expectedHeavyHitters, String testName) { Types.ExecMode platformOld = rtplatform; rtplatform = Types.ExecMode.SINGLE_NODE; - Thread t1 = null, t2 = null; + Thread[] workers = null; try { getAndLoadTestConfiguration(testName); @@ -128,8 +128,7 @@ private void loadAndRunTest(String[] expectedHeavyHitters, String testName) { int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - t2 = startLocalFedWorkerThread(port2); + workers = startLocalFedWorkerThreads(new int[] {port1, port2}, null, FED_WORKER_WAIT); // Run actual dml script with federated matrix fullDMLScriptName = HOME + testName + ".dml"; @@ -164,7 +163,7 @@ private void loadAndRunTest(String[] expectedHeavyHitters, String testName) { + Arrays.toString(missingHeavyHitters(expectedHeavyHitters))); } finally { - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedKMeansPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedKMeansPlanningTest.java index 9a9ff18d28b..b6c239854b3 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedKMeansPlanningTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedKMeansPlanningTest.java @@ -125,7 +125,7 @@ private void loadAndRunTest(String[] expectedHeavyHitters, String testName, Stri Types.ExecMode platformOld = rtplatform; rtplatform = Types.ExecMode.SINGLE_NODE; - Thread t1 = null, t2 = null; + Thread[] workers = null; try { getAndLoadTestConfiguration(testName); @@ -135,8 +135,7 @@ private void loadAndRunTest(String[] expectedHeavyHitters, String testName, Stri int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - t2 = startLocalFedWorkerThread(port2); + workers = startLocalFedWorkerThreads(new int[] {port1, port2}, null, FED_WORKER_WAIT); // Run actual dml script with federated matrix fullDMLScriptName = HOME + testName + ".dml"; @@ -158,7 +157,7 @@ private void loadAndRunTest(String[] expectedHeavyHitters, String testName, Stri // fail("The following expected heavy hitters are missing: " // + Arrays.toString(missingHeavyHitters(expectedHeavyHitters))); } finally { - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedL2SVMPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedL2SVMPlanningTest.java index 91143396208..39ab4901480 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedL2SVMPlanningTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedL2SVMPlanningTest.java @@ -170,7 +170,7 @@ private void loadAndRunTest(String[] expectedHeavyHitters, String testName, Stri Types.ExecMode platformOld = rtplatform; rtplatform = Types.ExecMode.SINGLE_NODE; - Thread t1 = null, t2 = null; + Thread[] workers = null; try { getAndLoadTestConfiguration(testName); @@ -180,8 +180,7 @@ private void loadAndRunTest(String[] expectedHeavyHitters, String testName, Stri int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - t2 = startLocalFedWorkerThread(port2); + workers = startLocalFedWorkerThreads(new int[] {port1, port2}, null, FED_WORKER_WAIT); // Run actual dml script with federated matrix fullDMLScriptName = HOME + testName + ".dml"; @@ -205,7 +204,7 @@ private void loadAndRunTest(String[] expectedHeavyHitters, String testName, Stri // + Arrays.toString(missingHeavyHitters(expectedHeavyHitters))); } finally { - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedMultiplyPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedMultiplyPlanningTest.java index 5b54f14d059..0a2bd230da4 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedMultiplyPlanningTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedMultiplyPlanningTest.java @@ -240,7 +240,7 @@ private void federatedTwoMatricesTest(Types.ExecMode execMode, String testName, if(rtplatform == Types.ExecMode.SPARK) { DMLScript.USE_LOCAL_SPARK_CONFIG = true; } - Thread t1 = null, t2 = null; + Thread[] workers = null; try{ getAndLoadTestConfiguration(testName); @@ -250,8 +250,7 @@ private void federatedTwoMatricesTest(Types.ExecMode execMode, String testName, int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - t2 = startLocalFedWorkerThread(port2); + workers = startLocalFedWorkerThreads(new int[] {port1, port2}, null, FED_WORKER_WAIT); // Run actual dml script with federated matrix fullDMLScriptName = HOME + testName + ".dml"; @@ -275,7 +274,7 @@ private void federatedTwoMatricesTest(Types.ExecMode execMode, String testName, fail("The following expected heavy hitters are missing: " + Arrays.toString(missingHeavyHitters(expectedHeavyHitters))); } finally { - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java index ff96ad8af2f..566d3ff323c 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java @@ -91,8 +91,7 @@ public void federatedRead(Types.ExecMode execMode, int workerCount) { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread t2 = startLocalFedWorkerThread(port2); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}, null, FED_WORKER_WAIT); String host = "localhost"; try { @@ -137,6 +136,6 @@ public void federatedRead(Types.ExecMode execMode, int workerCount) { resetExecMode(oldPlatform); } - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); } } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSSLTest.java b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSSLTest.java index 5d4d2c4e9c3..5f5c09e07cb 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSSLTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSSLTest.java @@ -96,8 +96,7 @@ public void federatedRead(Types.ExecMode execMode) { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread t2 = startLocalFedWorkerThread(port2); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}, null, FED_WORKER_WAIT); String host = "localhost"; @@ -134,7 +133,7 @@ public void federatedRead(Types.ExecMode execMode) { resetExecMode(oldPlatform); } - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); } /** diff --git a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSparsityPropagationTest.java b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSparsityPropagationTest.java index 42165d6b68f..f2654b9bb66 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSparsityPropagationTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSparsityPropagationTest.java @@ -102,8 +102,7 @@ public void federatedGet(ExecMode execMode, double sparsity) { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread t2 = startLocalFedWorkerThread(port2); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}, null, FED_WORKER_WAIT); getAndLoadTestConfiguration(TEST_NAME); @@ -136,7 +135,7 @@ public void federatedGet(ExecMode execMode, double sparsity) { compareNNZ(refNNZ, fedNNZ); - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); resetExecMode(platform_old); } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedWriterTest.java b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedWriterTest.java index d8bb7431472..68fb3a7da63 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedWriterTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedWriterTest.java @@ -83,8 +83,7 @@ public void federatedWrite(ExecMode execMode) { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread t2 = startLocalFedWorkerThread(port2); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}, null, FED_WORKER_WAIT); try { @@ -122,6 +121,6 @@ public void federatedWrite(ExecMode execMode) { resetExecMode(oldPlatform); } - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); } } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/multitenant/MultiTenantTestBase.java b/src/test/java/org/apache/sysds/test/functions/federated/multitenant/MultiTenantTestBase.java index c3a4756a2d5..866bbe00139 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/multitenant/MultiTenantTestBase.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/multitenant/MultiTenantTestBase.java @@ -38,8 +38,6 @@ import org.apache.sysds.test.AutomatedTestBase; import org.junit.After; -import com.google.crypto.tink.subtle.Random; - public abstract class MultiTenantTestBase extends AutomatedTestBase { protected static final Log LOG = LogFactory.getLog(MultiTenantTestBase.class.getName()); @@ -63,7 +61,8 @@ protected int[] startFedWorkers(int numFedWorkers) { } /** - * Start numFedWorkers federated worker processes on available ports and add them to the workerProcesses + * Start numFedWorkers federated worker processes on available ports and add them to the workerProcesses. + * Workers are spawned together and their port-bind is awaited in parallel. * * @param numFedWorkers the number of federated workers to start * @return int[] the ports of the created federated workers @@ -72,10 +71,10 @@ protected int[] startFedWorkers(int numFedWorkers, String[] addArgs) { int[] ports = new int[numFedWorkers]; for(int counter = 0; counter < numFedWorkers; counter++) { ports[counter] = getRandomAvailablePort(); - // start process but only wait long for last one. - Process tmpProcess = startLocalFedWorker(ports[counter], addArgs, - counter == numFedWorkers - 1 ? (FED_WORKER_WAIT + Random.randInt(1000)) * 3 : FED_WORKER_WAIT_S); - workerProcesses.add(tmpProcess); + } + Process[] processes = startLocalFedWorkers(ports, addArgs); + for(Process p : processes) { + workerProcesses.add(p); } return ports; } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBinaryMatrixTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBinaryMatrixTest.java index 8991d28194f..20945303b0b 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBinaryMatrixTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBinaryMatrixTest.java @@ -105,11 +105,10 @@ public void federatedMultiply(Types.ExecMode execMode, boolean federatedCompilat int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2}); try { - if(!isAlive(t1, t2)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); loadTestConfiguration(config); @@ -134,7 +133,7 @@ public void federatedMultiply(Types.ExecMode execMode, boolean federatedCompilat } finally { - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; OptimizerUtils.FEDERATED_COMPILATION = false; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBinaryVectorTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBinaryVectorTest.java index e038d9efda0..d47a7efcfa4 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBinaryVectorTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBinaryVectorTest.java @@ -97,11 +97,10 @@ public void federatedMultiply(Types.ExecMode execMode, boolean federatedCompilat int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2}); try { - if(!isAlive(t1, t2)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); @@ -127,7 +126,7 @@ public void federatedMultiply(Types.ExecMode execMode, boolean federatedCompilat } finally { - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; OptimizerUtils.FEDERATED_COMPILATION = false; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBroadcastTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBroadcastTest.java index 434eb08e262..1a88f7a0d39 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBroadcastTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBroadcastTest.java @@ -90,11 +90,10 @@ public void federatedBroadcast(Types.ExecMode execMode) { int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2}); try { - if(!isAlive(t1, t2)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); @@ -117,7 +116,7 @@ public void federatedBroadcast(Types.ExecMode execMode) { } finally { - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCastToFrameTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCastToFrameTest.java index a7bc0d80640..392e53aa77b 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCastToFrameTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCastToFrameTest.java @@ -97,12 +97,10 @@ public void federatedMultiply(Types.ExecMode execMode) { int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2}); - try { - if(!isAlive(t1, t2)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); @@ -127,7 +125,7 @@ public void federatedMultiply(Types.ExecMode execMode) { } finally { - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCastToMatrixTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCastToMatrixTest.java index ccfd4a6c430..832ae98d8ed 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCastToMatrixTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCastToMatrixTest.java @@ -122,10 +122,9 @@ public void federatedMultiply(Types.ExecMode execMode) { int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2}); try { - if(!isAlive(t1, t2)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); @@ -152,7 +151,7 @@ public void federatedMultiply(Types.ExecMode execMode) { } finally { - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCentralMomentTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCentralMomentTest.java index 1d9f951e78a..b52a1d60262 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCentralMomentTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCentralMomentTest.java @@ -118,14 +118,10 @@ public void federatedCentralMoment(Types.ExecMode execMode, boolean isWeighted) int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4, FED_WORKER_WAIT + 1000); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}, null, FED_WORKER_WAIT + 1000); - try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); // reference file should not be written to hdfs, so we set platform here @@ -187,7 +183,7 @@ public void federatedCentralMoment(Types.ExecMode execMode, boolean isWeighted) } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedColAggregateTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedColAggregateTest.java index f52a74bd4fc..5254378cf27 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedColAggregateTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedColAggregateTest.java @@ -174,13 +174,10 @@ private void runAggregateOperationTest(OpType type, ExecMode execMode) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); rtplatform = execMode; @@ -244,7 +241,7 @@ private void runAggregateOperationTest(OpType type, ExecMode execMode) { Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4"))); } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedLeftIndexTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedLeftIndexTest.java index 30d6b8fc0ec..6f9744ded1a 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedLeftIndexTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedLeftIndexTest.java @@ -165,13 +165,10 @@ else if(dataType == DataType.FRAME) int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); loadTestConfiguration(config); @@ -212,7 +209,7 @@ else if(dataType == DataType.FRAME) } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); } } finally { diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedMisAlignedTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedMisAlignedTest.java index 8fb8a80663e..06f86dd61b4 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedMisAlignedTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedMisAlignedTest.java @@ -220,14 +220,10 @@ private void runMisAlignedTest(OpType type, ExecMode execMode, MisAlignmentType int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); - try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); // Run reference dml script with normal matrix @@ -281,7 +277,7 @@ private void runMisAlignedTest(OpType type, ExecMode execMode, MisAlignmentType } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedMultiplyTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedMultiplyTest.java index 05e4954af15..cfdac1932ff 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedMultiplyTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedMultiplyTest.java @@ -114,12 +114,11 @@ private void federatedMultiply(Types.ExecMode execMode, boolean federatedCompila int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2}); try { - if(!isAlive(t1, t2)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); @@ -145,7 +144,7 @@ private void federatedMultiply(Types.ExecMode execMode, boolean federatedCompila } finally { - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; OptimizerUtils.FEDERATED_COMPILATION = false; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedProdTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedProdTest.java index 530a8b7b55e..2447143de2b 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedProdTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedProdTest.java @@ -108,14 +108,11 @@ private void runProdTest(ExecMode execMode) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); rtplatform = execMode; if(rtplatform == ExecMode.SPARK) { @@ -154,7 +151,7 @@ private void runProdTest(ExecMode execMode) { } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedQuantileTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedQuantileTest.java index 25d460526e9..6f13fce81b2 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedQuantileTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedQuantileTest.java @@ -164,12 +164,13 @@ public void federatedQuartile(Types.ExecMode execMode, String TEST_NAME, double port2 = getRandomAvailablePort(); port3 = getRandomAvailablePort(); port4 = getRandomAvailablePort(); - t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); + t1 = workers[0]; + t2 = workers[1]; + t3 = workers[2]; + t4 = workers[3]; - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); programArgs1 = new String[] {"-explain", "-stats", "100", "-args", String.valueOf(p), expected("S"), diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedQuantileWeightsTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedQuantileWeightsTest.java index dadbd6b5905..ebf4708b3e9 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedQuantileWeightsTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedQuantileWeightsTest.java @@ -132,12 +132,13 @@ public void federatedQuartile(Types.ExecMode execMode, String TEST_NAME, double port2 = getRandomAvailablePort(); port3 = getRandomAvailablePort(); port4 = getRandomAvailablePort(); - t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); + t1 = workers[0]; + t2 = workers[1]; + t3 = workers[2]; + t4 = workers[3]; - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); programArgs1 = new String[] {"-explain", "-stats", "100", "-args", String.valueOf(p), expected("S"), diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRCBindTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRCBindTest.java index d8ef8ca4c20..be69fac747d 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRCBindTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRCBindTest.java @@ -109,14 +109,11 @@ public void federatedRCBind(ExecMode execMode) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); // we need the reference file to not be written to hdfs, so we get the correct format @@ -158,7 +155,7 @@ public void federatedRCBind(ExecMode execMode) { } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRdiagTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRdiagTest.java index 46bf7a45650..ce6aba438e0 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRdiagTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRdiagTest.java @@ -112,14 +112,11 @@ public void federatedRdiag(Types.ExecMode execMode, boolean activateFedCompilati int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); ProgramBlock.CHECK_MATRIX_PROPERTIES = true; try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); // reference file should not be written to hdfs, so we set platform here @@ -159,7 +156,7 @@ public void federatedRdiag(Types.ExecMode execMode, boolean activateFedCompilati } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; OptimizerUtils.FEDERATED_COMPILATION = false; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRemoveEmptyTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRemoveEmptyTest.java index 71f7c583665..0714a6e6fb7 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRemoveEmptyTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRemoveEmptyTest.java @@ -121,14 +121,11 @@ private void runAggregateOperationTest(ExecMode execMode) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); rtplatform = execMode; @@ -166,7 +163,7 @@ private void runAggregateOperationTest(ExecMode execMode) { } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedReplaceTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedReplaceTest.java index d2ec98a4a8d..c5347d6276a 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedReplaceTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedReplaceTest.java @@ -108,13 +108,10 @@ private void runAggregateOperationTest(ExecMode execMode) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); rtplatform = execMode; if(rtplatform == ExecMode.SPARK) { @@ -153,7 +150,7 @@ private void runAggregateOperationTest(ExecMode execMode) { } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedReshapeTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedReshapeTest.java index 5d6887c3e2d..fb8bc619ec5 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedReshapeTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedReshapeTest.java @@ -101,14 +101,11 @@ public void federatedReshape(Types.ExecMode execMode) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); // reference file should not be written to hdfs, so we set platform here rtplatform = execMode; @@ -146,7 +143,7 @@ public void federatedReshape(Types.ExecMode execMode) { } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRevTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRevTest.java index e4b7ed5e24a..c4f30a574b0 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRevTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRevTest.java @@ -129,14 +129,11 @@ private void runRevTest(ExecMode execMode, boolean activateFedCompilation) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); rtplatform = execMode; if(rtplatform == ExecMode.SPARK) { @@ -176,7 +173,7 @@ private void runRevTest(ExecMode execMode, boolean activateFedCompilation) { } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRightIndexTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRightIndexTest.java index b8fe21ef52d..7a72a868b6e 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRightIndexTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRightIndexTest.java @@ -178,13 +178,10 @@ private void runAggregateOperationTest(IndexType indexType, DataType dataType, E int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); rtplatform = execMode; @@ -236,7 +233,7 @@ private void runAggregateOperationTest(IndexType indexType, DataType dataType, E } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRollTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRollTest.java index f242710338d..aa367ea7d4d 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRollTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRollTest.java @@ -131,14 +131,11 @@ private void runRollTest(ExecMode execMode, boolean activateFedCompilation) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if (!isAlive(t1, t2, t3, t4)) + if (!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); rtplatform = execMode; if (rtplatform == ExecMode.SPARK) { @@ -177,7 +174,7 @@ private void runRollTest(ExecMode execMode, boolean activateFedCompilation) { Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4"))); } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRowIndexTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRowIndexTest.java index 25b02ca1537..37b6f8826e4 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRowIndexTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRowIndexTest.java @@ -108,13 +108,10 @@ private void runRowIndexTest(ExecMode execMode) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); rtplatform = execMode; @@ -155,7 +152,7 @@ private void runRowIndexTest(ExecMode execMode) { } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedSplitTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedSplitTest.java index fc9e4a73ba5..462181e75ea 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedSplitTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedSplitTest.java @@ -106,11 +106,10 @@ public void federatedSplit(Types.ExecMode execMode) { int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2}); try { - if(!isAlive(t1, t2)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); // Run reference dml script with normal matrix @@ -141,7 +140,7 @@ else if(execMode != Types.ExecMode.SPARK) { } finally { - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedStatisticsTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedStatisticsTest.java index ac9706a44ae..dec5fac4432 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedStatisticsTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedStatisticsTest.java @@ -96,10 +96,9 @@ public void federatedLogReg(Types.ExecMode execMode) { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2}); try { - if(!isAlive(t1, t2)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); @@ -134,7 +133,7 @@ public void federatedLogReg(Types.ExecMode execMode) { } finally { - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); resetExecMode(platformOld); } } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTokenizeTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTokenizeTest.java index f8acc4623a7..84b632c5e2c 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTokenizeTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTokenizeTest.java @@ -91,12 +91,10 @@ private void runAggregateOperationTest(ExecMode execMode) { int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3}); try { - if(!isAlive(t1, t2, t3)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); FileFormatPropertiesCSV ffpCSV = new FileFormatPropertiesCSV(false, DataExpression.DEFAULT_DELIM_DELIMITER, @@ -143,7 +141,7 @@ private void runAggregateOperationTest(ExecMode execMode) { } finally { - TestUtils.shutdownThreads(t1, t2, t3); + TestUtils.shutdownThreads(workers); } } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTransferLocalDataTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTransferLocalDataTest.java index 0d29617eb76..2f2adf2bc93 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTransferLocalDataTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTransferLocalDataTest.java @@ -88,13 +88,10 @@ private void runTransferTest(Types.ExecMode execMode) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); rtplatform = execMode; @@ -125,7 +122,7 @@ private void runTransferTest(Types.ExecMode execMode) { } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); resetExecMode(platformOld); } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTriTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTriTest.java index 63432e44658..837fb979fc6 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTriTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTriTest.java @@ -108,12 +108,9 @@ private void runTriTest(ExecMode execMode) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); rtplatform = execMode; @@ -151,7 +148,7 @@ private void runTriTest(ExecMode execMode) { } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedCrossEntropyTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedCrossEntropyTest.java index bd48a45924f..5e0953d0aa5 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedCrossEntropyTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedCrossEntropyTest.java @@ -136,11 +136,10 @@ public void federatedWeightedCrossEntropy(String testname, ExecMode execMode) { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2}); try { - if(!isAlive(t1, t2)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); getAndLoadTestConfiguration(testname); @@ -172,7 +171,7 @@ public void federatedWeightedCrossEntropy(String testname, ExecMode execMode) { } finally { - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); resetExecMode(platform_old); } } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedDivMatrixMultTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedDivMatrixMultTest.java index 6753774f653..3f866be9844 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedDivMatrixMultTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedDivMatrixMultTest.java @@ -297,11 +297,10 @@ public void federatedWeightedDivMatrixMult(String test_name, ExecMode exec_mode) fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2}); try { - if(!isAlive(t1, t2)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); getAndLoadTestConfiguration(test_name); @@ -334,7 +333,7 @@ public void federatedWeightedDivMatrixMult(String test_name, ExecMode exec_mode) Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2"))); } finally { - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); resetExecMode(platform_old); } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedSigmoidTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedSigmoidTest.java index ad21760d086..15a709a2057 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedSigmoidTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedSigmoidTest.java @@ -161,11 +161,10 @@ public void federatedWeightedSigmoid(String test_name, ExecMode exec_mode) { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2}); try { - if(!isAlive(t1, t2)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); getAndLoadTestConfiguration(test_name); @@ -198,7 +197,7 @@ public void federatedWeightedSigmoid(String test_name, ExecMode exec_mode) { } finally { - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); resetExecMode(platform_old); } } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedSquaredLossTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedSquaredLossTest.java index 7fac163c4a6..111c16ded84 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedSquaredLossTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedSquaredLossTest.java @@ -151,11 +151,10 @@ public void federatedWeightedSquaredLoss(String test_name, ExecMode exec_mode) { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2}); try { - if(!isAlive(t1, t2)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); getAndLoadTestConfiguration(test_name); @@ -187,7 +186,7 @@ public void federatedWeightedSquaredLoss(String test_name, ExecMode exec_mode) { } finally { - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); resetExecMode(platform_old); } } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedUnaryMatrixMultTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedUnaryMatrixMultTest.java index 8019edf83f8..610b5ebedf7 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedUnaryMatrixMultTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedUnaryMatrixMultTest.java @@ -160,11 +160,10 @@ public void federatedWeightedUnaryMatrixMult(String test_name, ExecMode exec_mod fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2}); try { - if(!isAlive(t1, t2)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); getAndLoadTestConfiguration(test_name); @@ -197,7 +196,7 @@ public void federatedWeightedUnaryMatrixMult(String test_name, ExecMode exec_mod Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2"))); } finally { - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); resetExecMode(platform_old); } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part4/FederatedRowAggregateTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part4/FederatedRowAggregateTest.java index ba725519a8c..1088de53625 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part4/FederatedRowAggregateTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part4/FederatedRowAggregateTest.java @@ -214,14 +214,11 @@ private void runAggregateOperationTest(OpType type, ExecMode execMode) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); // Run reference dml script with normal matrix @@ -280,7 +277,7 @@ private void runAggregateOperationTest(OpType type, ExecMode execMode) { } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java index 48c9cab632a..0993b41698f 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java @@ -134,13 +134,10 @@ private void runCovarianceTest(ExecMode execMode, boolean alignedFedInput) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4, FED_WORKER_WAIT); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); setExecMode(execMode); @@ -203,7 +200,7 @@ private void runCovarianceTest(ExecMode execMode, boolean alignedFedInput) { } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; } @@ -246,13 +243,10 @@ private void runWeightedCovarianceTest(ExecMode execMode, boolean alignedInput, int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4, FED_WORKER_WAIT); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); loadTestConfiguration(config); @@ -364,7 +358,7 @@ private void runWeightedCovarianceTest(ExecMode execMode, boolean alignedInput, } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCtableTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCtableTest.java index 4ba19fb3d0f..ebb42dfec86 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCtableTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCtableTest.java @@ -133,14 +133,11 @@ public void runCtable(Types.ExecMode execMode, boolean fedOutput, boolean matrix int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); @@ -154,7 +151,7 @@ public void runCtable(Types.ExecMode execMode, boolean fedOutput, boolean matrix } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); resetExecMode(platformOld); } } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFrameMapTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFrameMapTest.java index 037675c584c..189c109fa8b 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFrameMapTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFrameMapTest.java @@ -101,14 +101,11 @@ private void runAggregateOperationTest(ExecMode execMode) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); rtplatform = execMode; @@ -149,7 +146,7 @@ private void runAggregateOperationTest(ExecMode execMode) { } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); } } } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullAggregateTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullAggregateTest.java index d75924746c5..006b5d76883 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullAggregateTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullAggregateTest.java @@ -203,13 +203,10 @@ private void runAggregateOperationTest(OpType type, ExecType instType) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); loadTestConfiguration(config); @@ -264,7 +261,7 @@ private void runAggregateOperationTest(OpType type, ExecType instType) { } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); resetExecMode(platformOld); } } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullCumulativeTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullCumulativeTest.java index 5eb1179efd9..b20d949857e 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullCumulativeTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullCumulativeTest.java @@ -168,14 +168,11 @@ private void runCumOperationTest(OpType type, ExecType instType) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); @@ -245,7 +242,7 @@ private void runCumOperationTest(OpType type, ExecType instType) { } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); resetExecMode(platformOld); } } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedIfelseTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedIfelseTest.java index abf0c8c228f..375a4f95181 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedIfelseTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedIfelseTest.java @@ -138,14 +138,11 @@ private void runTernaryTest(ExecMode execMode, boolean aligned, boolean singleMa int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1); - Process t2 = startLocalFedWorker(port2); - Process t3 = startLocalFedWorker(port3); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); rtplatform = execMode; @@ -220,7 +217,7 @@ private void runTernaryTest(ExecMode execMode, boolean aligned, boolean singleMa } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedMMChainTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedMMChainTest.java index c37341ed681..73cb5470a38 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedMMChainTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedMMChainTest.java @@ -138,14 +138,11 @@ private void runMMChainTest(ExecMode execMode, String TEST_NAME) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); rtplatform = execMode; @@ -181,7 +178,7 @@ private void runMMChainTest(ExecMode execMode, String TEST_NAME) { Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4"))); } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java index b70bdd2940b..fd1f88f2761 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java @@ -230,7 +230,7 @@ private void runTransformTest(TransformType type, boolean colnames, boolean line default: throw new RuntimeException("Not supported type"); } - Thread t1 = null, t2 = null, t3 = null, t4 = null; + Thread[] workers = null; try { getAndLoadTestConfiguration(TEST_NAME1); @@ -239,10 +239,7 @@ private void runTransformTest(TransformType type, boolean colnames, boolean line int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); String[] otherargs = lineage ? new String[] {"-lineage", "reuse_full"} : null; - t1 = startLocalFedWorkerThread(port1, otherargs); - t2 = startLocalFedWorkerThread(port2, otherargs); - t3 = startLocalFedWorkerThread(port3, otherargs); - t4 = startLocalFedWorkerThread(port4, otherargs); + workers = startLocalFedWorkerThreads(new int[] {port1, port2, port3, port4}, otherargs); FileFormatPropertiesCSV ffpCSV = new FileFormatPropertiesCSV(true, DataExpression.DEFAULT_DELIM_DELIMITER, DataExpression.DEFAULT_DELIM_FILL, DataExpression.DEFAULT_DELIM_FILL_VALUE, DATASET.equals(DATASET1) ? @@ -345,7 +342,7 @@ private void runTransformTest(TransformType type, boolean colnames, boolean line throw new RuntimeException(ex); } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); resetExecMode(rtold); } } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java index f144e03984a..8c94e62692d 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java @@ -127,7 +127,7 @@ public void runComplexRecodeTestBinarySparseCP() { private void runTransformEncodeDecodeTest(boolean recode, boolean sparse, Types.FileFormat format) { ExecMode rtold = setExecMode(ExecMode.SINGLE_NODE); - Thread t1 = null, t2 = null, t3 = null, t4 = null; + Thread[] workers = null; try { getAndLoadTestConfiguration(TEST_NAME_RECODE); @@ -135,10 +135,7 @@ private void runTransformEncodeDecodeTest(boolean recode, boolean sparse, Types. int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S); - t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S); - t4 = startLocalFedWorkerThread(port4); + workers = startLocalFedWorkerThreads(new int[] {port1, port2, port3, port4}, null, FED_WORKER_WAIT); // schema Types.ValueType[] schema = new Types.ValueType[cols / 2]; @@ -205,7 +202,7 @@ private void runTransformEncodeDecodeTest(boolean recode, boolean sparse, Types. Assert.fail(ex.getMessage()); } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); resetExecMode(rtold); } } diff --git a/src/test/java/org/apache/sysds/test/functions/lineage/FedFullReuseTest.java b/src/test/java/org/apache/sysds/test/functions/lineage/FedFullReuseTest.java index ca133caa76e..0c46cd68ea5 100644 --- a/src/test/java/org/apache/sysds/test/functions/lineage/FedFullReuseTest.java +++ b/src/test/java/org/apache/sysds/test/functions/lineage/FedFullReuseTest.java @@ -103,8 +103,7 @@ public void federatedReuse(String test) { int port2 = getRandomAvailablePort(); String[] otherargs = new String[] {"-lineage", "reuse_full"}; Lineage.resetInternalState(); - Thread t1 = startLocalFedWorkerThread(port1, otherargs, FED_WORKER_WAIT_S); - Thread t2 = startLocalFedWorkerThread(port2, otherargs); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}, otherargs, FED_WORKER_WAIT); TestConfiguration config = availableTestConfigurations.get(test); loadTestConfiguration(config); @@ -149,7 +148,7 @@ public void federatedReuse(String test) { } - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); } } diff --git a/src/test/java/org/apache/sysds/test/functions/lineage/FedUDFReuseTest.java b/src/test/java/org/apache/sysds/test/functions/lineage/FedUDFReuseTest.java index f7f01f6f9ef..0cf9d972719 100644 --- a/src/test/java/org/apache/sysds/test/functions/lineage/FedUDFReuseTest.java +++ b/src/test/java/org/apache/sysds/test/functions/lineage/FedUDFReuseTest.java @@ -108,10 +108,7 @@ private void runTriUDFReuse(ExecMode execMode) { int port4 = getRandomAvailablePort(); String[] otherargs = new String[] {"-lineage", "reuse_full"}; Lineage.resetInternalState(); - Thread t1 = startLocalFedWorkerThread(port1, otherargs, FED_WORKER_WAIT_S); - Thread t2 = startLocalFedWorkerThread(port2, otherargs, FED_WORKER_WAIT_S); - Thread t3 = startLocalFedWorkerThread(port3, otherargs, FED_WORKER_WAIT_S); - Thread t4 = startLocalFedWorkerThread(port4, otherargs); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2, port3, port4}, otherargs, FED_WORKER_WAIT); rtplatform = execMode; if(rtplatform == ExecMode.SPARK) { @@ -146,7 +143,7 @@ private void runTriUDFReuse(ExecMode execMode) { // assert reuse count Assert.assertTrue(LineageCacheStatistics.getInstHits() > 0); - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; diff --git a/src/test/java/org/apache/sysds/test/functions/lineage/LineageFedReuseAlg.java b/src/test/java/org/apache/sysds/test/functions/lineage/LineageFedReuseAlg.java index 15f6a7978e5..b8a66196670 100644 --- a/src/test/java/org/apache/sysds/test/functions/lineage/LineageFedReuseAlg.java +++ b/src/test/java/org/apache/sysds/test/functions/lineage/LineageFedReuseAlg.java @@ -93,8 +93,7 @@ public void federatedLmPipeline(ExecMode execMode, boolean contSplits, String TE int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); String[] otherargs = new String[] {"-lineage", "reuse_full"}; - Thread t1 = startLocalFedWorkerThread(port1, otherargs, FED_WORKER_WAIT_S); - Thread t2 = startLocalFedWorkerThread(port2, otherargs); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}, otherargs, FED_WORKER_WAIT); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); loadTestConfiguration(config); @@ -136,7 +135,7 @@ public void federatedLmPipeline(ExecMode execMode, boolean contSplits, String TE assertTrue(mmCount > mmCount_reuse); assertTrue(fed_mmCount > fed_mmCount_reuse); - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); } finally { resetExecMode(oldExec); From 20164d530227af7a18bbec66577256396e25fbc0 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Fri, 15 May 2026 18:31:10 +0000 Subject: [PATCH 3/3] [SYSTEMDS-2651] Rename federated worker readiness helper and tighten timeout clamp Rename org.apache.sysds.test.FederatedTestUtils to FederatedWorkerUtils to avoid the simple-name collision with the existing org.apache.sysds.test.component.federated.FederatedTestUtils, which covers unrelated federation RPC/data helpers. Lower the readiness-wait minimum clamp from 30s to 3s so a worker that fails to bind is reported within a few seconds of FED_WORKER_WAIT, instead of being held to a 30s floor before the timeout exception is raised. Also rewrite the Javadoc and inline comments on the new helpers in FederatedTestUtils -> FederatedWorkerUtils and on the federated worker helpers in AutomatedTestBase to a neutral, descriptive tone: - Document the MIN_TIMEOUT_MS clamp explicitly in every public overload rather than only inside the helper. - Split POLL_INTERVAL_MS into POLL_INTERVAL_MS (round sleep) and CONNECT_TIMEOUT_MS (per-attempt Socket.connect timeout) so the two uses are independently named. - Reword "start ... in parallel" to "start ... back to back, then wait for all of them in one shared poll loop" to match the actual sequential start + parallel readiness pattern. - Replace narration of the previous fixed-sleep API (and adjectives such as "tiny", "cheap", "sane floor", "right helper") with present-tense behavior contracts referencing the named constants. - Drop two stray "shared.!" typos in pre-existing thread-worker Javadoc. No behavior change beyond the 30s -> 3s clamp; mvn test-compile is clean and FederatedL2SVMTest (in-JVM threads) and FederatedTriTest (4 separate worker JVMs) both pass. --- .../apache/sysds/test/AutomatedTestBase.java | 93 ++++++++++--------- ...stUtils.java => FederatedWorkerUtils.java} | 81 ++++++++-------- 2 files changed, 91 insertions(+), 83 deletions(-) rename src/test/java/org/apache/sysds/test/{FederatedTestUtils.java => FederatedWorkerUtils.java} (64%) diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java index ea8df42a34d..46b18cf3476 100644 --- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java +++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java @@ -1642,9 +1642,11 @@ protected Process startLocalFedWorker(int port){ /** * Start a new JVM for a federated worker at the port. - * - * @param port Port to use for the JVM - * @param sleep The sleep time to wait for the worker to start + * + * @param port Port to use for the JVM + * @param sleep Upper bound on the wait for the worker to become ready, in ms; raised to a + * minimum value enforced inside {@link FederatedWorkerUtils}. Parameter name retained + * for source compatibility with the previous fixed-sleep API. * @return The process containing the worker */ protected Process startLocalFedWorker(int port, int sleep){ @@ -1666,30 +1668,26 @@ protected Process startLocalFedWorker(int port, String[] addArgs) { /** * Start new JVM for a federated worker at the port. * - *

Returns once the worker's TCP port is accepting connections (the - * worker only opens the port after Netty's bind has fully completed) or - * after {@code timeoutMs} elapses, in which case a {@link RuntimeException} - * is thrown. Replaces the previous fixed sleep so callers don't race against - * a worker that hasn't yet bound, and don't waste time after one that has. + *

Returns once the worker's TCP port accepts connections (the worker opens the port after + * Netty's bind completes), or throws a {@link RuntimeException} after {@code timeoutMs} elapses. * * @param port Port to use for the JVM * @param addArgs The arguments to add - * @param timeoutMs Upper bound on the wait for the worker to become ready - * (kept for source compatibility with the previous "sleep" - * parameter; clamped to a sane floor inside). + * @param timeoutMs Upper bound on the wait for the worker to become ready, in ms; raised to a + * minimum value enforced inside {@link FederatedWorkerUtils}. Parameter name + * retained for source compatibility with the previous fixed-sleep API. * @return the process associated with the worker. */ protected static Process startLocalFedWorker(int port, String[] addArgs, int timeoutMs) { Process process = spawnLocalFedWorker(port, addArgs); - FederatedTestUtils.waitForWorker(process, port, timeoutMs); + FederatedWorkerUtils.waitForWorker(process, port, timeoutMs); return process; } /** - * Start N federated worker JVMs in parallel, then wait for them all to be ready in one shared poll loop. - * - *

Wall-clock cost is bounded by the slowest worker's startup -- not the sum -- so this is the right - * helper for tests that need multiple federated workers. + * Start N federated worker JVMs back to back, then wait for all of them to become ready in one + * shared poll loop. The wall-clock wait scales with the slowest worker rather than the sum of the + * per-worker waits. * * @param ports Ports to use, one per worker * @return The process per port, in the same order as {@code ports}. @@ -1704,11 +1702,13 @@ protected static Process[] startLocalFedWorkers(int[] ports, String[] addArgs) { } /** - * Start N federated worker JVMs in parallel, then wait for them all to be ready. + * Start N federated worker JVMs back to back, then wait for all of them to become ready in one + * shared poll loop. * * @param ports Ports to use, one per worker * @param addArgs Extra worker CLI args (applied to every worker), or null - * @param timeoutMs Upper bound on the wait, in ms (clamped to a sane floor) + * @param timeoutMs Upper bound on the wait, in ms; raised to a minimum value enforced inside + * {@link FederatedWorkerUtils}. * @return The process per port, in the same order as {@code ports}. */ protected static Process[] startLocalFedWorkers(int[] ports, String[] addArgs, int timeoutMs) { @@ -1716,11 +1716,11 @@ protected static Process[] startLocalFedWorkers(int[] ports, String[] addArgs, i for(int i = 0; i < ports.length; i++) { processes[i] = spawnLocalFedWorker(ports[i], addArgs); } - FederatedTestUtils.waitForWorkers(processes, ports, timeoutMs); + FederatedWorkerUtils.waitForWorkers(processes, ports, timeoutMs); return processes; } - /** Spawn a federated worker JVM and return immediately, without waiting for the port to bind. */ + /** Spawn a federated worker JVM and return without waiting for the port to bind. */ private static Process spawnLocalFedWorker(int port, String[] addArgs) { String separator = System.getProperty("file.separator"); String classpath = System.getProperty("java.class.path"); @@ -1787,7 +1787,7 @@ protected Process startLocalFedMonitoring(int port, String[] addArgs) { } /** - * Start a thread for a worker. This will share the same JVM, so all static variables will be shared.! + * Start a thread for a worker. This will share the same JVM, so all static variables will be shared. * * Also when using the local Fed Worker thread the statistics printing, and clearing from the worker is disabled. * @@ -1813,12 +1813,14 @@ public static Thread startLocalFedWorkerThread(int port, String[] otherArgs) { } /** - * Start a thread for a worker. This will share the same JVM, so all static variables will be shared.! + * Start a thread for a worker. This will share the same JVM, so all static variables will be shared. * * Also when using the local Fed Worker thread the statistics printing, and clearing from the worker is disabled. * * @param port Port to use - * @param sleep The amount of time to wait for the worker startup. in Milliseconds + * @param sleep Upper bound on the wait for the worker to become ready, in ms; raised to a minimum + * value enforced inside {@link FederatedWorkerUtils}. Parameter name retained for source + * compatibility with the previous fixed-sleep API. * @return The thread associated with the worker. */ public static Thread startLocalFedWorkerThread(int port, int sleep) { @@ -1831,25 +1833,26 @@ public static Thread startLocalFedWorkerThread(int port, int sleep) { *

Also when using the local Fed Worker thread the statistics printing, and clearing from the worker is * disabled. * - *

Returns once the worker's TCP port is accepting connections (the worker only opens the port after - * Netty's bind has fully completed) or after {@code timeoutMs} elapses, in which case a - * {@link RuntimeException} is thrown. + *

Returns once the worker's TCP port accepts connections (the worker opens the port after Netty's bind + * completes), or throws a {@link RuntimeException} after {@code timeoutMs} elapses. * * @param port Port to use * @param otherArgs The command line arguments to start the worker with - * @param timeoutMs Upper bound on the wait for the worker to become ready (kept for source compatibility - * with the previous "sleep" parameter; clamped to a sane floor inside). + * @param timeoutMs Upper bound on the wait for the worker to become ready, in ms; raised to a + * minimum value enforced inside {@link FederatedWorkerUtils}. Parameter name + * retained for source compatibility with the previous fixed-sleep API. * @return The thread associated with the worker. */ public static Thread startLocalFedWorkerThread(int port, String[] otherArgs, int timeoutMs) { Thread t = spawnLocalFedWorkerThread(port, otherArgs); - FederatedTestUtils.waitForWorker(t, port, timeoutMs); + FederatedWorkerUtils.waitForWorker(t, port, timeoutMs); return t; } /** - * Start N federated worker threads in the same JVM in parallel, then wait for them all to be ready in - * one shared poll loop. + * Start N federated worker threads in the same JVM back to back, then wait for all of them to + * become ready in one shared poll loop. The wall-clock wait scales with the slowest worker rather + * than the sum of the per-worker waits. * * @param ports Ports to use, one per worker * @return The thread per port, in the same order as {@code ports}. @@ -1864,21 +1867,22 @@ public static Thread[] startLocalFedWorkerThreads(int[] ports, String[] otherArg } /** - * Start N federated worker threads in the same JVM in parallel, then wait for them all to be ready. + * Start N federated worker threads in the same JVM back to back, then wait for all of them to + * become ready in one shared poll loop. * * @param ports Ports to use, one per worker * @param otherArgs Extra worker CLI args (applied to every worker), or null - * @param timeoutMs Upper bound on the wait, in ms (clamped to a sane floor) + * @param timeoutMs Upper bound on the wait, in ms; raised to a minimum value enforced inside + * {@link FederatedWorkerUtils}. * @return The thread per port, in the same order as {@code ports}. */ public static Thread[] startLocalFedWorkerThreads(int[] ports, String[] otherArgs, int timeoutMs) { Thread[] threads = new Thread[ports.length]; for(int i = 0; i < ports.length; i++) { threads[i] = spawnLocalFedWorkerThread(ports[i], otherArgs); - // Tiny stagger between in-JVM thread spawns: lets each worker pull through the shared - // DMLScript / FederatedWorker static-init (LineageCacheConfig setters etc.) without all - // threads racing through the same global state simultaneously. Cheap relative to total - // startup; restores the implicit ordering the old per-thread sleep provided. + // Sleep THREAD_SPAWN_STAGGER_MS between in-JVM thread spawns to reduce contention on + // shared static initialization in DMLScript / FederatedWorker (e.g. LineageCacheConfig + // setters) when multiple worker threads enter main() concurrently. if(i + 1 < ports.length) { try { java.util.concurrent.TimeUnit.MILLISECONDS.sleep(THREAD_SPAWN_STAGGER_MS); @@ -1889,13 +1893,13 @@ public static Thread[] startLocalFedWorkerThreads(int[] ports, String[] otherArg } } } - FederatedTestUtils.waitForWorkers(threads, ports, timeoutMs); + FederatedWorkerUtils.waitForWorkers(threads, ports, timeoutMs); return threads; } private static final int THREAD_SPAWN_STAGGER_MS = 25; - /** Spawn a federated worker thread in this JVM and return immediately, without waiting for the port to bind. */ + /** Spawn a federated worker thread in this JVM and return without waiting for the port to bind. */ private static Thread spawnLocalFedWorkerThread(int port, String[] otherArgs) { ArrayList args = new ArrayList<>(); args.add("-w"); @@ -1938,9 +1942,9 @@ public static boolean isAlive(Process... processes) { /** * Start java worker in same JVM. * - *

Returns once the worker's TCP port is accepting connections (the worker only opens the port after - * Netty's bind has fully completed) or after the default federated worker timeout elapses, in which case - * a {@link RuntimeException} is thrown. The port is extracted from {@code args} (expects {@code "-w" }). + *

Returns once the worker's TCP port accepts connections (the worker opens the port after + * Netty's bind completes), or throws a {@link RuntimeException} after the default federated worker + * timeout elapses. The port is extracted from {@code args}, which must contain {@code "-w" }. * * @param args the command line arguments * @return the thread associated with the worker. @@ -1952,11 +1956,12 @@ public static Thread startLocalFedWorkerWithArgs(String[] args) { main(args); } catch(IOException e) { - // suppressed: matches previous behaviour of this helper + // IOException from main() is not propagated here; startup failures surface as a + // readiness timeout from waitForWorker below. } }); t.start(); - FederatedTestUtils.waitForWorker(t, port, FED_WORKER_WAIT); + FederatedWorkerUtils.waitForWorker(t, port, FED_WORKER_WAIT); return t; } diff --git a/src/test/java/org/apache/sysds/test/FederatedTestUtils.java b/src/test/java/org/apache/sysds/test/FederatedWorkerUtils.java similarity index 64% rename from src/test/java/org/apache/sysds/test/FederatedTestUtils.java rename to src/test/java/org/apache/sysds/test/FederatedWorkerUtils.java index d328fdee69e..bd12f7c3bdc 100644 --- a/src/test/java/org/apache/sysds/test/FederatedTestUtils.java +++ b/src/test/java/org/apache/sysds/test/FederatedWorkerUtils.java @@ -25,51 +25,47 @@ import java.util.function.BooleanSupplier; /** - * Test-only helpers for waiting on federated worker startup. + * Test helpers that block until a federated worker is accepting TCP connections on its port. * - *

The federated worker only opens its TCP port once Netty's - * {@code bind().sync()} has fully completed, so a successful TCP connect to - * that port is a precise readiness signal. Tests use this instead of fixed - * sleeps so they finish as soon as the worker is up and fail fast if the - * worker never comes up (or dies during startup). - * - *

This lives in the test tree on purpose; the production - * {@code FederatedWorker} carries no test-only callback hooks. + *

The federated worker opens its TCP port after Netty's {@code bind().sync()} returns; a successful + * TCP connect to that port therefore indicates that the worker is ready to accept requests. The methods + * here poll for that signal and throw {@link RuntimeException} on timeout or if the underlying + * {@code Process}/{@code Thread} exits before the port becomes ready. */ -public final class FederatedTestUtils { +public final class FederatedWorkerUtils { - /** How often to retry the TCP connect while waiting (ms). */ + /** Sleep between successive poll rounds, in milliseconds. */ private static final int POLL_INTERVAL_MS = 25; + /** Per-attempt {@link Socket#connect} timeout, in milliseconds. */ + private static final int CONNECT_TIMEOUT_MS = 25; + /** - * Lower clamp on the caller-supplied timeout. Several historical call sites - * pass values like {@code FED_WORKER_WAIT_S = 50}, which were intended as a - * "short fixed sleep" before the next worker was started in parallel. With - * polling we exit as soon as the worker is ready, so only the upper bound - * matters; clamping prevents those tiny legacy values from turning into - * false-positive timeouts on slow CI machines. + * Minimum value applied to the caller-supplied {@code timeoutMs}. The wait returns as soon as the + * worker accepts a connection, so this only affects the upper bound used when a worker never becomes + * ready. Set to 3s so a worker that fails to bind is reported quickly while still leaving room for + * Netty bind under typical local/CI conditions. */ - private static final int MIN_TIMEOUT_MS = 30_000; + private static final int MIN_TIMEOUT_MS = 3_000; - private FederatedTestUtils() { + private FederatedWorkerUtils() { // utility class } /** - * Block until a federated worker is accepting TCP connections on - * {@code port}, or throw a {@link RuntimeException} after the timeout. + * Block until a federated worker is accepting TCP connections on {@code port}, or throw a + * {@link RuntimeException} after the effective timeout elapses. * * @param port port the federated worker is expected to bind - * @param timeoutMs upper bound on the wait, in ms (clamped to a sane floor) + * @param timeoutMs upper bound on the wait, in ms; raised to {@link #MIN_TIMEOUT_MS} if smaller */ public static void waitForWorker(int port, int timeoutMs) { waitForWorker(port, timeoutMs, () -> true, "worker"); } /** - * Block until a federated worker is accepting TCP connections on - * {@code port}. Fails fast if the supplied {@code aliveCheck} reports the - * worker is no longer alive. + * Block until a federated worker is accepting TCP connections on {@code port}. Returns early with + * a {@link RuntimeException} if {@code aliveCheck} reports the worker is no longer alive. */ public static void waitForWorker(int port, int timeoutMs, BooleanSupplier aliveCheck, String workerKind) { final int effectiveTimeout = Math.max(timeoutMs, MIN_TIMEOUT_MS); @@ -88,30 +84,33 @@ public static void waitForWorker(int port, int timeoutMs, BooleanSupplier aliveC + " did not become ready within " + effectiveTimeout + "ms."); } - /** Convenience: also fail fast if the worker process dies during startup. */ + /** Overload that also returns early if the given worker process exits before the port is ready. */ public static void waitForWorker(Process process, int port, int timeoutMs) { waitForWorker(port, timeoutMs, process::isAlive, "worker process"); } - /** Convenience: also fail fast if the worker thread dies during startup. */ + /** Overload that also returns early if the given worker thread exits before the port is ready. */ public static void waitForWorker(Thread thread, int port, int timeoutMs) { waitForWorker(port, timeoutMs, thread::isAlive, "worker thread"); } /** - * Block until all the listed federated workers are accepting TCP - * connections. Polls all ports in a single shared loop, so when N workers - * are spawned in parallel and then this is called, total wall-clock wait is - * bounded by the slowest worker -- not the sum of individual startup times. + * Block until every listed federated worker is accepting TCP connections. All ports are polled in + * one shared loop, so the wall-clock wait is bounded by the slowest worker rather than the sum of + * individual waits. * * @param ports ports the workers are expected to bind - * @param timeoutMs upper bound on the wait, in ms (clamped to a sane floor) + * @param timeoutMs upper bound on the wait, in ms; raised to {@link #MIN_TIMEOUT_MS} if smaller */ public static void waitForWorkers(int[] ports, int timeoutMs) { waitForWorkers(ports, timeoutMs, i -> true, "workers"); } - /** Convenience: also fail fast if any of the worker processes dies during startup. */ + /** + * Overload that also returns early if any of the worker processes exits before its port is ready. + * + * @throws IllegalArgumentException if {@code processes.length != ports.length} + */ public static void waitForWorkers(Process[] processes, int[] ports, int timeoutMs) { if(processes.length != ports.length) { throw new IllegalArgumentException( @@ -120,7 +119,11 @@ public static void waitForWorkers(Process[] processes, int[] ports, int timeoutM waitForWorkers(ports, timeoutMs, i -> processes[i].isAlive(), "worker processes"); } - /** Convenience: also fail fast if any of the worker threads dies during startup. */ + /** + * Overload that also returns early if any of the worker threads exits before its port is ready. + * + * @throws IllegalArgumentException if {@code threads.length != ports.length} + */ public static void waitForWorkers(Thread[] threads, int[] ports, int timeoutMs) { if(threads.length != ports.length) { throw new IllegalArgumentException( @@ -130,10 +133,10 @@ public static void waitForWorkers(Thread[] threads, int[] ports, int timeoutMs) } /** - * Bulk variant: takes a per-index alive check so the caller can plug in - * {@code Process} or {@code Thread} liveness. Marks each port ready as soon - * as it accepts a connection, and only sleeps between rounds (not between - * port checks), so all ports race in parallel. + * Bulk variant taking a per-index liveness predicate so callers can plug in either {@code Process} + * or {@code Thread} liveness. Each port flips to ready as soon as it accepts a connection; the loop + * sleeps once per round (after probing all pending ports) rather than between individual port + * probes. */ public static void waitForWorkers(int[] ports, int timeoutMs, java.util.function.IntPredicate aliveCheck, String workerKind) { @@ -175,7 +178,7 @@ public static void waitForWorkers(int[] ports, int timeoutMs, java.util.function private static boolean tryConnect(int port) { try(Socket s = new Socket()) { - s.connect(new InetSocketAddress("localhost", port), POLL_INTERVAL_MS); + s.connect(new InetSocketAddress("localhost", port), CONNECT_TIMEOUT_MS); return true; } catch(IOException e) {