diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java index c7f62b02a2b..46b18cf3476 100644 --- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java +++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java @@ -1642,9 +1642,11 @@ protected Process startLocalFedWorker(int port){ /** * Start a new JVM for a federated worker at the port. - * - * @param port Port to use for the JVM - * @param sleep The sleep time to wait for the worker to start + * + * @param port Port to use for the JVM + * @param sleep Upper bound on the wait for the worker to become ready, in ms; raised to a + * minimum value enforced inside {@link FederatedWorkerUtils}. Parameter name retained + * for source compatibility with the previous fixed-sleep API. * @return The process containing the worker */ protected Process startLocalFedWorker(int port, int sleep){ @@ -1665,18 +1667,65 @@ protected Process startLocalFedWorker(int port, String[] addArgs) { /** * Start new JVM for a federated worker at the port. - * - * @param port Port to use for the JVM - * @param addArgs The arguments to add - * @param sleep The time to wait for the process to start + * + *

Returns once the worker's TCP port accepts connections (the worker opens the port after + * Netty's bind completes), or throws a {@link RuntimeException} after {@code timeoutMs} elapses. + * + * @param port Port to use for the JVM + * @param addArgs The arguments to add + * @param timeoutMs Upper bound on the wait for the worker to become ready, in ms; raised to a + * minimum value enforced inside {@link FederatedWorkerUtils}. Parameter name + * retained for source compatibility with the previous fixed-sleep API. * @return the process associated with the worker. */ - protected static Process startLocalFedWorker(int port, String[] addArgs, int sleep) { - Process process = null; + protected static Process startLocalFedWorker(int port, String[] addArgs, int timeoutMs) { + Process process = spawnLocalFedWorker(port, addArgs); + FederatedWorkerUtils.waitForWorker(process, port, timeoutMs); + return process; + } + + /** + * Start N federated worker JVMs back to back, then wait for all of them to become ready in one + * shared poll loop. The wall-clock wait scales with the slowest worker rather than the sum of the + * per-worker waits. + * + * @param ports Ports to use, one per worker + * @return The process per port, in the same order as {@code ports}. + */ + protected static Process[] startLocalFedWorkers(int[] ports) { + return startLocalFedWorkers(ports, null, FED_WORKER_WAIT); + } + + /** @see #startLocalFedWorkers(int[], String[], int) */ + protected static Process[] startLocalFedWorkers(int[] ports, String[] addArgs) { + return startLocalFedWorkers(ports, addArgs, FED_WORKER_WAIT); + } + + /** + * Start N federated worker JVMs back to back, then wait for all of them to become ready in one + * shared poll loop. + * + * @param ports Ports to use, one per worker + * @param addArgs Extra worker CLI args (applied to every worker), or null + * @param timeoutMs Upper bound on the wait, in ms; raised to a minimum value enforced inside + * {@link FederatedWorkerUtils}. + * @return The process per port, in the same order as {@code ports}. + */ + protected static Process[] startLocalFedWorkers(int[] ports, String[] addArgs, int timeoutMs) { + Process[] processes = new Process[ports.length]; + for(int i = 0; i < ports.length; i++) { + processes[i] = spawnLocalFedWorker(ports[i], addArgs); + } + FederatedWorkerUtils.waitForWorkers(processes, ports, timeoutMs); + return processes; + } + + /** Spawn a federated worker JVM and return without waiting for the port to bind. */ + private static Process spawnLocalFedWorker(int port, String[] addArgs) { String separator = System.getProperty("file.separator"); String classpath = System.getProperty("java.class.path"); String path = System.getProperty("java.home") + separator + "bin" + separator + "java"; - String[] args = new String[] {path, "-Xmx1000m", "-Xms1000m", "-Xmn100m", + String[] args = new String[] {path, "-Xmx1000m", "-Xms1000m", "-Xmn100m", "--add-opens=java.base/java.nio=ALL-UNNAMED" , "--add-opens=java.base/java.io=ALL-UNNAMED" , "--add-opens=java.base/java.util=ALL-UNNAMED" , @@ -1701,19 +1750,14 @@ protected static Process startLocalFedWorker(int port, String[] addArgs, int sle DMLScript.class.getName(), "-w", Integer.toString(port), "-stats"}); if(addArgs != null) args = ArrayUtils.addAll(args, addArgs); - - ProcessBuilder processBuilder = new ProcessBuilder(args).inheritIO(); + ProcessBuilder processBuilder = new ProcessBuilder(args).inheritIO(); try { - process = processBuilder.start(); - // Give some time to startup the worker. - sleep(sleep); + return processBuilder.start(); } - catch(IOException | InterruptedException e) { - e.printStackTrace(); + catch(IOException e) { + throw new RuntimeException("Failed to launch federated worker process on port " + port, e); } - isAlive(process); - return process; } /** @@ -1743,7 +1787,7 @@ protected Process startLocalFedMonitoring(int port, String[] addArgs) { } /** - * Start a thread for a worker. This will share the same JVM, so all static variables will be shared.! + * Start a thread for a worker. This will share the same JVM, so all static variables will be shared. * * Also when using the local Fed Worker thread the statistics printing, and clearing from the worker is disabled. * @@ -1769,12 +1813,14 @@ public static Thread startLocalFedWorkerThread(int port, String[] otherArgs) { } /** - * Start a thread for a worker. This will share the same JVM, so all static variables will be shared.! + * Start a thread for a worker. This will share the same JVM, so all static variables will be shared. * * Also when using the local Fed Worker thread the statistics printing, and clearing from the worker is disabled. * * @param port Port to use - * @param sleep The amount of time to wait for the worker startup. in Milliseconds + * @param sleep Upper bound on the wait for the worker to become ready, in ms; raised to a minimum + * value enforced inside {@link FederatedWorkerUtils}. Parameter name retained for source + * compatibility with the previous fixed-sleep API. * @return The thread associated with the worker. */ public static Thread startLocalFedWorkerThread(int port, int sleep) { @@ -1782,50 +1828,99 @@ public static Thread startLocalFedWorkerThread(int port, int sleep) { } /** - * Start a thread for a worker. This will share the same JVM, so all static variables will be shared.! - * - * Also when using the local Fed Worker thread the statistics printing, and clearing from the worker is disabled. - * + * Start a thread for a worker. This will share the same JVM, so all static variables will be shared. + * + *

Also when using the local Fed Worker thread the statistics printing, and clearing from the worker is + * disabled. + * + *

Returns once the worker's TCP port accepts connections (the worker opens the port after Netty's bind + * completes), or throws a {@link RuntimeException} after {@code timeoutMs} elapses. + * * @param port Port to use * @param otherArgs The command line arguments to start the worker with - * @param sleep The amount of time to wait for the worker startup. in Milliseconds + * @param timeoutMs Upper bound on the wait for the worker to become ready, in ms; raised to a + * minimum value enforced inside {@link FederatedWorkerUtils}. Parameter name + * retained for source compatibility with the previous fixed-sleep API. * @return The thread associated with the worker. */ - public static Thread startLocalFedWorkerThread(int port, String[] otherArgs, int sleep) { + public static Thread startLocalFedWorkerThread(int port, String[] otherArgs, int timeoutMs) { + Thread t = spawnLocalFedWorkerThread(port, otherArgs); + FederatedWorkerUtils.waitForWorker(t, port, timeoutMs); + return t; + } + + /** + * Start N federated worker threads in the same JVM back to back, then wait for all of them to + * become ready in one shared poll loop. The wall-clock wait scales with the slowest worker rather + * than the sum of the per-worker waits. + * + * @param ports Ports to use, one per worker + * @return The thread per port, in the same order as {@code ports}. + */ + public static Thread[] startLocalFedWorkerThreads(int[] ports) { + return startLocalFedWorkerThreads(ports, null, FED_WORKER_WAIT); + } + /** @see #startLocalFedWorkerThreads(int[], String[], int) */ + public static Thread[] startLocalFedWorkerThreads(int[] ports, String[] otherArgs) { + return startLocalFedWorkerThreads(ports, otherArgs, FED_WORKER_WAIT); + } + + /** + * Start N federated worker threads in the same JVM back to back, then wait for all of them to + * become ready in one shared poll loop. + * + * @param ports Ports to use, one per worker + * @param otherArgs Extra worker CLI args (applied to every worker), or null + * @param timeoutMs Upper bound on the wait, in ms; raised to a minimum value enforced inside + * {@link FederatedWorkerUtils}. + * @return The thread per port, in the same order as {@code ports}. + */ + public static Thread[] startLocalFedWorkerThreads(int[] ports, String[] otherArgs, int timeoutMs) { + Thread[] threads = new Thread[ports.length]; + for(int i = 0; i < ports.length; i++) { + threads[i] = spawnLocalFedWorkerThread(ports[i], otherArgs); + // Sleep THREAD_SPAWN_STAGGER_MS between in-JVM thread spawns to reduce contention on + // shared static initialization in DMLScript / FederatedWorker (e.g. LineageCacheConfig + // setters) when multiple worker threads enter main() concurrently. + if(i + 1 < ports.length) { + try { + java.util.concurrent.TimeUnit.MILLISECONDS.sleep(THREAD_SPAWN_STAGGER_MS); + } + catch(InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted while spawning federated worker threads", e); + } + } + } + FederatedWorkerUtils.waitForWorkers(threads, ports, timeoutMs); + return threads; + } + + private static final int THREAD_SPAWN_STAGGER_MS = 25; + + /** Spawn a federated worker thread in this JVM and return without waiting for the port to bind. */ + private static Thread spawnLocalFedWorkerThread(int port, String[] otherArgs) { ArrayList args = new ArrayList<>(); - args.add("-w"); args.add(Integer.toString(port)); - if(otherArgs != null) - for( String s : otherArgs) + for(String s : otherArgs) args.add(s); String[] finalArguments = args.toArray(new String[args.size()]); Statistics.allowWorkerStatistics = false; - try { - Thread t = new Thread(() -> { - try { - main(finalArguments); - } - catch(Exception e) { - LOG.error("Exception in startup of federated worker", e); - } - }); - t.start(); - java.util.concurrent.TimeUnit.MILLISECONDS.sleep(sleep); - if(!t.isAlive()) - throw new RuntimeException("Failed starting federated worker"); - return t; - } - catch(InterruptedException e) { - e.printStackTrace(); - fail("Failed to start federated worker : " + e); - // should never happen - return null; - } + Thread t = new Thread(() -> { + try { + main(finalArguments); + } + catch(Exception e) { + LOG.error("Exception in startup of federated worker", e); + } + }); + t.start(); + return t; } public static boolean isAlive(Thread... threads){ @@ -1846,28 +1941,44 @@ public static boolean isAlive(Process... processes) { /** * Start java worker in same JVM. - * + * + *

Returns once the worker's TCP port accepts connections (the worker opens the port after + * Netty's bind completes), or throws a {@link RuntimeException} after the default federated worker + * timeout elapses. The port is extracted from {@code args}, which must contain {@code "-w" }. + * * @param args the command line arguments - * @return the thread associated with the process.s + * @return the thread associated with the worker. */ public static Thread startLocalFedWorkerWithArgs(String[] args) { - Thread t = null; + final int port = extractWorkerPort(args); + Thread t = new Thread(() -> { + try { + main(args); + } + catch(IOException e) { + // IOException from main() is not propagated here; startup failures surface as a + // readiness timeout from waitForWorker below. + } + }); + t.start(); + FederatedWorkerUtils.waitForWorker(t, port, FED_WORKER_WAIT); + return t; + } - try { - t = new Thread(() -> { + private static int extractWorkerPort(String[] args) { + for(int i = 0; i < args.length - 1; i++) { + if("-w".equals(args[i])) { try { - main(args); + return Integer.parseInt(args[i + 1]); } - catch(IOException e) { + catch(NumberFormatException e) { + throw new IllegalArgumentException( + "Federated worker args contain non-numeric port after -w: " + args[i + 1], e); } - }); - t.start(); - java.util.concurrent.TimeUnit.MILLISECONDS.sleep(FED_WORKER_WAIT); - } - catch(InterruptedException e) { - // Should happen at closing of the worker so don't print + } } - return t; + throw new IllegalArgumentException("Federated worker args must contain '-w ': " + + Arrays.toString(args)); } private boolean rCompareException(boolean exceptionExpected, String errMessage, Throwable e, boolean result) { diff --git a/src/test/java/org/apache/sysds/test/FederatedWorkerUtils.java b/src/test/java/org/apache/sysds/test/FederatedWorkerUtils.java new file mode 100644 index 00000000000..bd12f7c3bdc --- /dev/null +++ b/src/test/java/org/apache/sysds/test/FederatedWorkerUtils.java @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.Socket; +import java.util.function.BooleanSupplier; + +/** + * Test helpers that block until a federated worker is accepting TCP connections on its port. + * + *

The federated worker opens its TCP port after Netty's {@code bind().sync()} returns; a successful + * TCP connect to that port therefore indicates that the worker is ready to accept requests. The methods + * here poll for that signal and throw {@link RuntimeException} on timeout or if the underlying + * {@code Process}/{@code Thread} exits before the port becomes ready. + */ +public final class FederatedWorkerUtils { + + /** Sleep between successive poll rounds, in milliseconds. */ + private static final int POLL_INTERVAL_MS = 25; + + /** Per-attempt {@link Socket#connect} timeout, in milliseconds. */ + private static final int CONNECT_TIMEOUT_MS = 25; + + /** + * Minimum value applied to the caller-supplied {@code timeoutMs}. The wait returns as soon as the + * worker accepts a connection, so this only affects the upper bound used when a worker never becomes + * ready. Set to 3s so a worker that fails to bind is reported quickly while still leaving room for + * Netty bind under typical local/CI conditions. + */ + private static final int MIN_TIMEOUT_MS = 3_000; + + private FederatedWorkerUtils() { + // utility class + } + + /** + * Block until a federated worker is accepting TCP connections on {@code port}, or throw a + * {@link RuntimeException} after the effective timeout elapses. + * + * @param port port the federated worker is expected to bind + * @param timeoutMs upper bound on the wait, in ms; raised to {@link #MIN_TIMEOUT_MS} if smaller + */ + public static void waitForWorker(int port, int timeoutMs) { + waitForWorker(port, timeoutMs, () -> true, "worker"); + } + + /** + * Block until a federated worker is accepting TCP connections on {@code port}. Returns early with + * a {@link RuntimeException} if {@code aliveCheck} reports the worker is no longer alive. + */ + public static void waitForWorker(int port, int timeoutMs, BooleanSupplier aliveCheck, String workerKind) { + final int effectiveTimeout = Math.max(timeoutMs, MIN_TIMEOUT_MS); + final long deadline = System.currentTimeMillis() + effectiveTimeout; + while(System.currentTimeMillis() < deadline) { + if(!aliveCheck.getAsBoolean()) { + throw new RuntimeException( + "Federated " + workerKind + " on port " + port + " died before becoming ready."); + } + if(tryConnect(port)) { + return; + } + sleepQuietly(); + } + throw new RuntimeException("Federated " + workerKind + " on port " + port + + " did not become ready within " + effectiveTimeout + "ms."); + } + + /** Overload that also returns early if the given worker process exits before the port is ready. */ + public static void waitForWorker(Process process, int port, int timeoutMs) { + waitForWorker(port, timeoutMs, process::isAlive, "worker process"); + } + + /** Overload that also returns early if the given worker thread exits before the port is ready. */ + public static void waitForWorker(Thread thread, int port, int timeoutMs) { + waitForWorker(port, timeoutMs, thread::isAlive, "worker thread"); + } + + /** + * Block until every listed federated worker is accepting TCP connections. All ports are polled in + * one shared loop, so the wall-clock wait is bounded by the slowest worker rather than the sum of + * individual waits. + * + * @param ports ports the workers are expected to bind + * @param timeoutMs upper bound on the wait, in ms; raised to {@link #MIN_TIMEOUT_MS} if smaller + */ + public static void waitForWorkers(int[] ports, int timeoutMs) { + waitForWorkers(ports, timeoutMs, i -> true, "workers"); + } + + /** + * Overload that also returns early if any of the worker processes exits before its port is ready. + * + * @throws IllegalArgumentException if {@code processes.length != ports.length} + */ + public static void waitForWorkers(Process[] processes, int[] ports, int timeoutMs) { + if(processes.length != ports.length) { + throw new IllegalArgumentException( + "processes/ports length mismatch: " + processes.length + " vs " + ports.length); + } + waitForWorkers(ports, timeoutMs, i -> processes[i].isAlive(), "worker processes"); + } + + /** + * Overload that also returns early if any of the worker threads exits before its port is ready. + * + * @throws IllegalArgumentException if {@code threads.length != ports.length} + */ + public static void waitForWorkers(Thread[] threads, int[] ports, int timeoutMs) { + if(threads.length != ports.length) { + throw new IllegalArgumentException( + "threads/ports length mismatch: " + threads.length + " vs " + ports.length); + } + waitForWorkers(ports, timeoutMs, i -> threads[i].isAlive(), "worker threads"); + } + + /** + * Bulk variant taking a per-index liveness predicate so callers can plug in either {@code Process} + * or {@code Thread} liveness. Each port flips to ready as soon as it accepts a connection; the loop + * sleeps once per round (after probing all pending ports) rather than between individual port + * probes. + */ + public static void waitForWorkers(int[] ports, int timeoutMs, java.util.function.IntPredicate aliveCheck, + String workerKind) { + final int effectiveTimeout = Math.max(timeoutMs, MIN_TIMEOUT_MS); + final long deadline = System.currentTimeMillis() + effectiveTimeout; + final boolean[] ready = new boolean[ports.length]; + int remaining = ports.length; + while(remaining > 0 && System.currentTimeMillis() < deadline) { + boolean madeProgress = false; + for(int i = 0; i < ports.length; i++) { + if(ready[i]) { + continue; + } + if(!aliveCheck.test(i)) { + throw new RuntimeException("Federated " + workerKind + " on port " + ports[i] + + " died before becoming ready."); + } + if(tryConnect(ports[i])) { + ready[i] = true; + remaining--; + madeProgress = true; + } + } + if(remaining > 0 && !madeProgress) { + sleepQuietly(); + } + } + if(remaining > 0) { + StringBuilder sb = new StringBuilder("Federated ").append(workerKind) + .append(" did not all become ready within ").append(effectiveTimeout).append("ms. Pending ports:"); + for(int i = 0; i < ports.length; i++) { + if(!ready[i]) { + sb.append(' ').append(ports[i]); + } + } + throw new RuntimeException(sb.toString()); + } + } + + private static boolean tryConnect(int port) { + try(Socket s = new Socket()) { + s.connect(new InetSocketAddress("localhost", port), CONNECT_TIMEOUT_MS); + return true; + } + catch(IOException e) { + return false; + } + } + + private static void sleepQuietly() { + try { + Thread.sleep(POLL_INTERVAL_MS); + } + catch(InterruptedException ie) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted while waiting for federated worker", ie); + } + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedAlsCGTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedAlsCGTest.java index 5e880166afe..3b53c71c0ea 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedAlsCGTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedAlsCGTest.java @@ -120,8 +120,7 @@ public void federatedAlsCG(String testname, ExecMode execMode) fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread thread1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread thread2 = startLocalFedWorkerThread(port2); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}); getAndLoadTestConfiguration(testname); @@ -153,7 +152,7 @@ public void federatedAlsCG(String testname, ExecMode execMode) HashMap fedResults = readDMLMatrixFromOutputDir(OUTPUT_NAME); TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, "Fed", "Ref"); - TestUtils.shutdownThreads(thread1, thread2); + TestUtils.shutdownThreads(workers); // check for federated operations Assert.assertTrue(heavyHittersContainsString("fed_!=")); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedBivarTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedBivarTest.java index f2f35700119..0537ecf18de 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedBivarTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedBivarTest.java @@ -114,10 +114,7 @@ public void federatedL2SVM(Types.ExecMode execMode) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S); - Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S); - Thread t4 = startLocalFedWorkerThread(port4); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2, port3, port4}); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); loadTestConfiguration(config); @@ -139,7 +136,7 @@ public void federatedL2SVM(Types.ExecMode execMode) { // compare via files compareResults(1e-9); - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); // check that federated input files are still existing Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1"))); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedCorTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedCorTest.java index 73bf8e91de6..2c81da6333e 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedCorTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedCorTest.java @@ -102,10 +102,7 @@ private void runAggregateOperationTest(ExecMode execMode) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S); - Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S); - Thread t4 = startLocalFedWorkerThread(port4); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2, port3, port4}); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); loadTestConfiguration(config); @@ -140,7 +137,7 @@ private void runAggregateOperationTest(ExecMode execMode) { Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3"))); Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4"))); - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java index 6d8e8165305..b54cb3d0a64 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java @@ -100,8 +100,7 @@ public void federatedGLM(Types.ExecMode execMode) { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread t2 = startLocalFedWorkerThread(port2); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); loadTestConfiguration(config); @@ -122,7 +121,7 @@ public void federatedGLM(Types.ExecMode execMode) { // compare via files compareResults(1e-2); - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); // check for federated operations Assert.assertTrue(heavyHittersContainsString("fed_ba+*")); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java index c8605ac3d88..c1fc83ae3c3 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java @@ -113,8 +113,7 @@ public void federatedKmeans(Types.ExecMode execMode, boolean singleWorker) { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread t2 = startLocalFedWorkerThread(port2); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); loadTestConfiguration(config); @@ -155,7 +154,7 @@ public void federatedKmeans(Types.ExecMode execMode, boolean singleWorker) { // compare via files // compareResults(1e-9); --> randomized - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); resetExecMode(platformOld); } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedL2SVMTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedL2SVMTest.java index f7040e6a4dc..0b38317f73b 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedL2SVMTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedL2SVMTest.java @@ -104,8 +104,7 @@ public void federatedL2SVM(Types.ExecMode execMode, boolean singleWorker) { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread t2 = startLocalFedWorkerThread(port2); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); loadTestConfiguration(config); @@ -126,7 +125,7 @@ public void federatedL2SVM(Types.ExecMode execMode, boolean singleWorker) { // compare via files compareResults(1e-9); - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLmPipeline.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLmPipeline.java index 8b02246e0aa..f3e921dd935 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLmPipeline.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLmPipeline.java @@ -109,10 +109,7 @@ public void federatedLmPipeline(ExecMode execMode, boolean contSplits, String TE int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S); - Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S); - Thread t4 = startLocalFedWorkerThread(port4); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2, port3, port4}); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); loadTestConfiguration(config); @@ -134,7 +131,7 @@ public void federatedLmPipeline(ExecMode execMode, boolean contSplits, String TE // compare via files compareResults(1e-2); - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); // check correct federated operations diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java index a3e91ef37d7..74882637d95 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java @@ -95,8 +95,7 @@ public void federatedLogReg(Types.ExecMode execMode) { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread t2 = startLocalFedWorkerThread(port2); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); loadTestConfiguration(config); @@ -117,7 +116,7 @@ public void federatedLogReg(Types.ExecMode execMode) { // compare via files compareResults(1e-9); - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); // check for federated operations Assert.assertTrue("contains fed_ba+*", heavyHittersContainsString("fed_ba+*")); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedMSVMTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedMSVMTest.java index c5344f9d847..97a4d84a0b9 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedMSVMTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedMSVMTest.java @@ -97,8 +97,7 @@ public void federatedMSVM(Types.ExecMode execMode, boolean singleWorker) { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread t2 = startLocalFedWorkerThread(port2); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); loadTestConfiguration(config); @@ -119,7 +118,7 @@ public void federatedMSVM(Types.ExecMode execMode, boolean singleWorker) { // compare via files compareResults(1e-9); - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPCATest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPCATest.java index 2a4186680c5..9ddd59483ce 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPCATest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPCATest.java @@ -102,10 +102,7 @@ public void federatedPCA(Types.ExecMode execMode) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S); - Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S); - Thread t4 = startLocalFedWorkerThread(port4); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2, port3, port4}); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); loadTestConfiguration(config); @@ -127,7 +124,7 @@ public void federatedPCA(Types.ExecMode execMode) { // compare via files compareResults(1e-9); - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); // check for federated operations Assert.assertTrue(heavyHittersContainsString("fed_ba+*")); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPNMFTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPNMFTest.java index 19fb72ce47b..183f9900a6d 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPNMFTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPNMFTest.java @@ -113,8 +113,7 @@ public void federatedPNMF(ExecMode execMode) fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread thread1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread thread2 = startLocalFedWorkerThread(port2); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}); getAndLoadTestConfiguration(TEST_NAME); @@ -141,7 +140,7 @@ public void federatedPNMF(ExecMode execMode) HashMap fedResults = readDMLMatrixFromOutputDir(OUTPUT_NAME); TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, "Fed", "Ref"); - TestUtils.shutdownThreads(thread1, thread2); + TestUtils.shutdownThreads(workers); // check for federated operations Assert.assertTrue(heavyHittersContainsString("fed_wcemm")); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedUnivarTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedUnivarTest.java index a4a8236e3a8..21f2e392250 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedUnivarTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedUnivarTest.java @@ -100,10 +100,7 @@ public void federatedL2SVM(Types.ExecMode execMode) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S); - Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S); - Thread t4 = startLocalFedWorkerThread(port4); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2, port3, port4}); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); loadTestConfiguration(config); @@ -126,7 +123,7 @@ public void federatedL2SVM(Types.ExecMode execMode) { // compare via files compareResults(1e-9); - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); // check for federated operations Assert.assertTrue(heavyHittersContainsString("fed_uacmax")); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedVarTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedVarTest.java index db071220d00..445abfa1485 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedVarTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedVarTest.java @@ -115,10 +115,7 @@ private void runAggregateOperationTest(ExecMode execMode, boolean federatedCompi int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S); - Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S); - Thread t4 = startLocalFedWorkerThread(port4); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2, port3, port4}); rtplatform = execMode; if(rtplatform == ExecMode.SPARK) @@ -157,7 +154,7 @@ private void runAggregateOperationTest(ExecMode execMode, boolean federatedCompi Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3"))); Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4"))); - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); OptimizerUtils.FEDERATED_COMPILATION = false; rtplatform = platformOld; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedYL2SVMTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedYL2SVMTest.java index b8eef26f35a..a87328a8aa4 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedYL2SVMTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedYL2SVMTest.java @@ -117,8 +117,7 @@ public void federatedL2SVM(Types.ExecMode execMode, String testName) { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread t2 = startLocalFedWorkerThread(port2); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}); TestConfiguration config = availableTestConfigurations.get(testName); loadTestConfiguration(config); @@ -140,7 +139,7 @@ public void federatedL2SVM(Types.ExecMode execMode, String testName) { // compare via files compareResults(1e-9); - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCellwiseTmplTest.java b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCellwiseTmplTest.java index 2f6e5e46173..94fab208e09 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCellwiseTmplTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCellwiseTmplTest.java @@ -158,8 +158,7 @@ private void testFederatedCodegen(ExecMode exec_mode) { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread thread1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread thread2 = startLocalFedWorkerThread(port2); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}); getAndLoadTestConfiguration(TEST_NAME); @@ -188,7 +187,7 @@ private void testFederatedCodegen(ExecMode exec_mode) { HashMap fedResults = readDMLMatrixFromOutputDir(OUTPUT_NAME); TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, "Fed", "Ref"); - TestUtils.shutdownThreads(thread1, thread2); + TestUtils.shutdownThreads(workers); // check for federated operations Assert.assertTrue(heavyHittersContainsSubString("fed_spoofCell")); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCodegenMultipleFedMOTest.java b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCodegenMultipleFedMOTest.java index 61722dbc467..a23219e1587 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCodegenMultipleFedMOTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCodegenMultipleFedMOTest.java @@ -194,8 +194,7 @@ private void testFederatedCodegenMultipleFedMO(ExecMode exec_mode) { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread thread1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread thread2 = startLocalFedWorkerThread(port2); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}); getAndLoadTestConfiguration(TEST_NAME); @@ -228,7 +227,7 @@ private void testFederatedCodegenMultipleFedMO(ExecMode exec_mode) { HashMap fedResults = readDMLMatrixFromOutputDir(OUTPUT_NAME); TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, "Fed", "Ref"); - TestUtils.shutdownThreads(thread1, thread2); + TestUtils.shutdownThreads(workers); // check for federated operations if(test_num >= 0 && test_num < 100) diff --git a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedMultiAggTmplTest.java b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedMultiAggTmplTest.java index 33a551d28d0..74cd2c0a420 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedMultiAggTmplTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedMultiAggTmplTest.java @@ -143,8 +143,7 @@ private void testFederatedCodegenMultiAgg(ExecMode exec_mode) { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread thread1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread thread2 = startLocalFedWorkerThread(port2); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}); getAndLoadTestConfiguration(TEST_NAME); @@ -173,7 +172,7 @@ private void testFederatedCodegenMultiAgg(ExecMode exec_mode) { HashMap fedResults = readDMLMatrixFromOutputDir(OUTPUT_NAME); TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, "Fed", "Ref"); - TestUtils.shutdownThreads(thread1, thread2); + TestUtils.shutdownThreads(workers); // check for federated operations Assert.assertTrue(heavyHittersContainsSubString("fed_spoofMA")); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedOuterProductTmplTest.java b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedOuterProductTmplTest.java index cef5fd5e99f..dd08ae58636 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedOuterProductTmplTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedOuterProductTmplTest.java @@ -146,8 +146,7 @@ private void testFederatedCodegenOuterProduct(ExecMode exec_mode) { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread thread1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread thread2 = startLocalFedWorkerThread(port2); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}); getAndLoadTestConfiguration(TEST_NAME); @@ -176,7 +175,7 @@ private void testFederatedCodegenOuterProduct(ExecMode exec_mode) { HashMap fedResults = readDMLMatrixFromOutputDir(OUTPUT_NAME); TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, "Fed", "Ref"); - TestUtils.shutdownThreads(thread1, thread2); + TestUtils.shutdownThreads(workers); // check for federated operations Assert.assertTrue(heavyHittersContainsSubString("fed_spoofOP")); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedRowwiseTmplTest.java b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedRowwiseTmplTest.java index b47a718c0ca..185c15557d1 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedRowwiseTmplTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedRowwiseTmplTest.java @@ -147,8 +147,7 @@ private void testFederatedCodegenRowwise(ExecMode exec_mode) { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread thread1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread thread2 = startLocalFedWorkerThread(port2); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}); getAndLoadTestConfiguration(TEST_NAME); @@ -178,7 +177,7 @@ private void testFederatedCodegenRowwise(ExecMode exec_mode) { TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, "Fed", "Ref"); - TestUtils.shutdownThreads(thread1, thread2); + TestUtils.shutdownThreads(workers); // check for federated operations if(row_partitioned) diff --git a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedDynamicPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedDynamicPlanningTest.java index 53dce3f01c7..bcc2b61adf2 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedDynamicPlanningTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedDynamicPlanningTest.java @@ -118,7 +118,7 @@ private void loadAndRunTest(String[] expectedHeavyHitters, String testName) { Types.ExecMode platformOld = rtplatform; rtplatform = Types.ExecMode.SINGLE_NODE; - Thread t1 = null, t2 = null; + Thread[] workers = null; try { getAndLoadTestConfiguration(testName); @@ -128,8 +128,7 @@ private void loadAndRunTest(String[] expectedHeavyHitters, String testName) { int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - t2 = startLocalFedWorkerThread(port2); + workers = startLocalFedWorkerThreads(new int[] {port1, port2}, null, FED_WORKER_WAIT); // Run actual dml script with federated matrix fullDMLScriptName = HOME + testName + ".dml"; @@ -164,7 +163,7 @@ private void loadAndRunTest(String[] expectedHeavyHitters, String testName) { + Arrays.toString(missingHeavyHitters(expectedHeavyHitters))); } finally { - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedKMeansPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedKMeansPlanningTest.java index 9a9ff18d28b..b6c239854b3 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedKMeansPlanningTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedKMeansPlanningTest.java @@ -125,7 +125,7 @@ private void loadAndRunTest(String[] expectedHeavyHitters, String testName, Stri Types.ExecMode platformOld = rtplatform; rtplatform = Types.ExecMode.SINGLE_NODE; - Thread t1 = null, t2 = null; + Thread[] workers = null; try { getAndLoadTestConfiguration(testName); @@ -135,8 +135,7 @@ private void loadAndRunTest(String[] expectedHeavyHitters, String testName, Stri int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - t2 = startLocalFedWorkerThread(port2); + workers = startLocalFedWorkerThreads(new int[] {port1, port2}, null, FED_WORKER_WAIT); // Run actual dml script with federated matrix fullDMLScriptName = HOME + testName + ".dml"; @@ -158,7 +157,7 @@ private void loadAndRunTest(String[] expectedHeavyHitters, String testName, Stri // fail("The following expected heavy hitters are missing: " // + Arrays.toString(missingHeavyHitters(expectedHeavyHitters))); } finally { - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedL2SVMPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedL2SVMPlanningTest.java index 91143396208..39ab4901480 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedL2SVMPlanningTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedL2SVMPlanningTest.java @@ -170,7 +170,7 @@ private void loadAndRunTest(String[] expectedHeavyHitters, String testName, Stri Types.ExecMode platformOld = rtplatform; rtplatform = Types.ExecMode.SINGLE_NODE; - Thread t1 = null, t2 = null; + Thread[] workers = null; try { getAndLoadTestConfiguration(testName); @@ -180,8 +180,7 @@ private void loadAndRunTest(String[] expectedHeavyHitters, String testName, Stri int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - t2 = startLocalFedWorkerThread(port2); + workers = startLocalFedWorkerThreads(new int[] {port1, port2}, null, FED_WORKER_WAIT); // Run actual dml script with federated matrix fullDMLScriptName = HOME + testName + ".dml"; @@ -205,7 +204,7 @@ private void loadAndRunTest(String[] expectedHeavyHitters, String testName, Stri // + Arrays.toString(missingHeavyHitters(expectedHeavyHitters))); } finally { - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedMultiplyPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedMultiplyPlanningTest.java index 5b54f14d059..0a2bd230da4 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedMultiplyPlanningTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedMultiplyPlanningTest.java @@ -240,7 +240,7 @@ private void federatedTwoMatricesTest(Types.ExecMode execMode, String testName, if(rtplatform == Types.ExecMode.SPARK) { DMLScript.USE_LOCAL_SPARK_CONFIG = true; } - Thread t1 = null, t2 = null; + Thread[] workers = null; try{ getAndLoadTestConfiguration(testName); @@ -250,8 +250,7 @@ private void federatedTwoMatricesTest(Types.ExecMode execMode, String testName, int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - t2 = startLocalFedWorkerThread(port2); + workers = startLocalFedWorkerThreads(new int[] {port1, port2}, null, FED_WORKER_WAIT); // Run actual dml script with federated matrix fullDMLScriptName = HOME + testName + ".dml"; @@ -275,7 +274,7 @@ private void federatedTwoMatricesTest(Types.ExecMode execMode, String testName, fail("The following expected heavy hitters are missing: " + Arrays.toString(missingHeavyHitters(expectedHeavyHitters))); } finally { - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java index ff96ad8af2f..566d3ff323c 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java @@ -91,8 +91,7 @@ public void federatedRead(Types.ExecMode execMode, int workerCount) { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread t2 = startLocalFedWorkerThread(port2); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}, null, FED_WORKER_WAIT); String host = "localhost"; try { @@ -137,6 +136,6 @@ public void federatedRead(Types.ExecMode execMode, int workerCount) { resetExecMode(oldPlatform); } - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); } } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSSLTest.java b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSSLTest.java index 5d4d2c4e9c3..5f5c09e07cb 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSSLTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSSLTest.java @@ -96,8 +96,7 @@ public void federatedRead(Types.ExecMode execMode) { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread t2 = startLocalFedWorkerThread(port2); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}, null, FED_WORKER_WAIT); String host = "localhost"; @@ -134,7 +133,7 @@ public void federatedRead(Types.ExecMode execMode) { resetExecMode(oldPlatform); } - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); } /** diff --git a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSparsityPropagationTest.java b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSparsityPropagationTest.java index 42165d6b68f..f2654b9bb66 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSparsityPropagationTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSparsityPropagationTest.java @@ -102,8 +102,7 @@ public void federatedGet(ExecMode execMode, double sparsity) { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread t2 = startLocalFedWorkerThread(port2); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}, null, FED_WORKER_WAIT); getAndLoadTestConfiguration(TEST_NAME); @@ -136,7 +135,7 @@ public void federatedGet(ExecMode execMode, double sparsity) { compareNNZ(refNNZ, fedNNZ); - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); resetExecMode(platform_old); } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedWriterTest.java b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedWriterTest.java index d8bb7431472..68fb3a7da63 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedWriterTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedWriterTest.java @@ -83,8 +83,7 @@ public void federatedWrite(ExecMode execMode) { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - Thread t2 = startLocalFedWorkerThread(port2); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}, null, FED_WORKER_WAIT); try { @@ -122,6 +121,6 @@ public void federatedWrite(ExecMode execMode) { resetExecMode(oldPlatform); } - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); } } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/multitenant/MultiTenantTestBase.java b/src/test/java/org/apache/sysds/test/functions/federated/multitenant/MultiTenantTestBase.java index c3a4756a2d5..866bbe00139 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/multitenant/MultiTenantTestBase.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/multitenant/MultiTenantTestBase.java @@ -38,8 +38,6 @@ import org.apache.sysds.test.AutomatedTestBase; import org.junit.After; -import com.google.crypto.tink.subtle.Random; - public abstract class MultiTenantTestBase extends AutomatedTestBase { protected static final Log LOG = LogFactory.getLog(MultiTenantTestBase.class.getName()); @@ -63,7 +61,8 @@ protected int[] startFedWorkers(int numFedWorkers) { } /** - * Start numFedWorkers federated worker processes on available ports and add them to the workerProcesses + * Start numFedWorkers federated worker processes on available ports and add them to the workerProcesses. + * Workers are spawned together and their port-bind is awaited in parallel. * * @param numFedWorkers the number of federated workers to start * @return int[] the ports of the created federated workers @@ -72,10 +71,10 @@ protected int[] startFedWorkers(int numFedWorkers, String[] addArgs) { int[] ports = new int[numFedWorkers]; for(int counter = 0; counter < numFedWorkers; counter++) { ports[counter] = getRandomAvailablePort(); - // start process but only wait long for last one. - Process tmpProcess = startLocalFedWorker(ports[counter], addArgs, - counter == numFedWorkers - 1 ? (FED_WORKER_WAIT + Random.randInt(1000)) * 3 : FED_WORKER_WAIT_S); - workerProcesses.add(tmpProcess); + } + Process[] processes = startLocalFedWorkers(ports, addArgs); + for(Process p : processes) { + workerProcesses.add(p); } return ports; } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBinaryMatrixTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBinaryMatrixTest.java index 8991d28194f..20945303b0b 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBinaryMatrixTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBinaryMatrixTest.java @@ -105,11 +105,10 @@ public void federatedMultiply(Types.ExecMode execMode, boolean federatedCompilat int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2}); try { - if(!isAlive(t1, t2)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); loadTestConfiguration(config); @@ -134,7 +133,7 @@ public void federatedMultiply(Types.ExecMode execMode, boolean federatedCompilat } finally { - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; OptimizerUtils.FEDERATED_COMPILATION = false; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBinaryVectorTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBinaryVectorTest.java index e038d9efda0..d47a7efcfa4 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBinaryVectorTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBinaryVectorTest.java @@ -97,11 +97,10 @@ public void federatedMultiply(Types.ExecMode execMode, boolean federatedCompilat int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2}); try { - if(!isAlive(t1, t2)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); @@ -127,7 +126,7 @@ public void federatedMultiply(Types.ExecMode execMode, boolean federatedCompilat } finally { - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; OptimizerUtils.FEDERATED_COMPILATION = false; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBroadcastTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBroadcastTest.java index 434eb08e262..1a88f7a0d39 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBroadcastTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBroadcastTest.java @@ -90,11 +90,10 @@ public void federatedBroadcast(Types.ExecMode execMode) { int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2}); try { - if(!isAlive(t1, t2)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); @@ -117,7 +116,7 @@ public void federatedBroadcast(Types.ExecMode execMode) { } finally { - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCastToFrameTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCastToFrameTest.java index a7bc0d80640..392e53aa77b 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCastToFrameTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCastToFrameTest.java @@ -97,12 +97,10 @@ public void federatedMultiply(Types.ExecMode execMode) { int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2}); - try { - if(!isAlive(t1, t2)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); @@ -127,7 +125,7 @@ public void federatedMultiply(Types.ExecMode execMode) { } finally { - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCastToMatrixTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCastToMatrixTest.java index ccfd4a6c430..832ae98d8ed 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCastToMatrixTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCastToMatrixTest.java @@ -122,10 +122,9 @@ public void federatedMultiply(Types.ExecMode execMode) { int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2}); try { - if(!isAlive(t1, t2)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); @@ -152,7 +151,7 @@ public void federatedMultiply(Types.ExecMode execMode) { } finally { - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCentralMomentTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCentralMomentTest.java index 1d9f951e78a..b52a1d60262 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCentralMomentTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCentralMomentTest.java @@ -118,14 +118,10 @@ public void federatedCentralMoment(Types.ExecMode execMode, boolean isWeighted) int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4, FED_WORKER_WAIT + 1000); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}, null, FED_WORKER_WAIT + 1000); - try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); // reference file should not be written to hdfs, so we set platform here @@ -187,7 +183,7 @@ public void federatedCentralMoment(Types.ExecMode execMode, boolean isWeighted) } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedColAggregateTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedColAggregateTest.java index f52a74bd4fc..5254378cf27 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedColAggregateTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedColAggregateTest.java @@ -174,13 +174,10 @@ private void runAggregateOperationTest(OpType type, ExecMode execMode) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); rtplatform = execMode; @@ -244,7 +241,7 @@ private void runAggregateOperationTest(OpType type, ExecMode execMode) { Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4"))); } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedLeftIndexTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedLeftIndexTest.java index 30d6b8fc0ec..6f9744ded1a 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedLeftIndexTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedLeftIndexTest.java @@ -165,13 +165,10 @@ else if(dataType == DataType.FRAME) int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); loadTestConfiguration(config); @@ -212,7 +209,7 @@ else if(dataType == DataType.FRAME) } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); } } finally { diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedMisAlignedTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedMisAlignedTest.java index 8fb8a80663e..06f86dd61b4 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedMisAlignedTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedMisAlignedTest.java @@ -220,14 +220,10 @@ private void runMisAlignedTest(OpType type, ExecMode execMode, MisAlignmentType int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); - try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); // Run reference dml script with normal matrix @@ -281,7 +277,7 @@ private void runMisAlignedTest(OpType type, ExecMode execMode, MisAlignmentType } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedMultiplyTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedMultiplyTest.java index 05e4954af15..cfdac1932ff 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedMultiplyTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedMultiplyTest.java @@ -114,12 +114,11 @@ private void federatedMultiply(Types.ExecMode execMode, boolean federatedCompila int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2}); try { - if(!isAlive(t1, t2)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); @@ -145,7 +144,7 @@ private void federatedMultiply(Types.ExecMode execMode, boolean federatedCompila } finally { - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; OptimizerUtils.FEDERATED_COMPILATION = false; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedProdTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedProdTest.java index 530a8b7b55e..2447143de2b 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedProdTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedProdTest.java @@ -108,14 +108,11 @@ private void runProdTest(ExecMode execMode) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); rtplatform = execMode; if(rtplatform == ExecMode.SPARK) { @@ -154,7 +151,7 @@ private void runProdTest(ExecMode execMode) { } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedQuantileTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedQuantileTest.java index 25d460526e9..6f13fce81b2 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedQuantileTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedQuantileTest.java @@ -164,12 +164,13 @@ public void federatedQuartile(Types.ExecMode execMode, String TEST_NAME, double port2 = getRandomAvailablePort(); port3 = getRandomAvailablePort(); port4 = getRandomAvailablePort(); - t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); + t1 = workers[0]; + t2 = workers[1]; + t3 = workers[2]; + t4 = workers[3]; - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); programArgs1 = new String[] {"-explain", "-stats", "100", "-args", String.valueOf(p), expected("S"), diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedQuantileWeightsTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedQuantileWeightsTest.java index dadbd6b5905..ebf4708b3e9 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedQuantileWeightsTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedQuantileWeightsTest.java @@ -132,12 +132,13 @@ public void federatedQuartile(Types.ExecMode execMode, String TEST_NAME, double port2 = getRandomAvailablePort(); port3 = getRandomAvailablePort(); port4 = getRandomAvailablePort(); - t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); + t1 = workers[0]; + t2 = workers[1]; + t3 = workers[2]; + t4 = workers[3]; - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); programArgs1 = new String[] {"-explain", "-stats", "100", "-args", String.valueOf(p), expected("S"), diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRCBindTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRCBindTest.java index d8ef8ca4c20..be69fac747d 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRCBindTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRCBindTest.java @@ -109,14 +109,11 @@ public void federatedRCBind(ExecMode execMode) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); // we need the reference file to not be written to hdfs, so we get the correct format @@ -158,7 +155,7 @@ public void federatedRCBind(ExecMode execMode) { } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRdiagTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRdiagTest.java index 46bf7a45650..ce6aba438e0 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRdiagTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRdiagTest.java @@ -112,14 +112,11 @@ public void federatedRdiag(Types.ExecMode execMode, boolean activateFedCompilati int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); ProgramBlock.CHECK_MATRIX_PROPERTIES = true; try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); // reference file should not be written to hdfs, so we set platform here @@ -159,7 +156,7 @@ public void federatedRdiag(Types.ExecMode execMode, boolean activateFedCompilati } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; OptimizerUtils.FEDERATED_COMPILATION = false; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRemoveEmptyTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRemoveEmptyTest.java index 71f7c583665..0714a6e6fb7 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRemoveEmptyTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRemoveEmptyTest.java @@ -121,14 +121,11 @@ private void runAggregateOperationTest(ExecMode execMode) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); rtplatform = execMode; @@ -166,7 +163,7 @@ private void runAggregateOperationTest(ExecMode execMode) { } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedReplaceTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedReplaceTest.java index d2ec98a4a8d..c5347d6276a 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedReplaceTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedReplaceTest.java @@ -108,13 +108,10 @@ private void runAggregateOperationTest(ExecMode execMode) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); rtplatform = execMode; if(rtplatform == ExecMode.SPARK) { @@ -153,7 +150,7 @@ private void runAggregateOperationTest(ExecMode execMode) { } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedReshapeTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedReshapeTest.java index 5d6887c3e2d..fb8bc619ec5 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedReshapeTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedReshapeTest.java @@ -101,14 +101,11 @@ public void federatedReshape(Types.ExecMode execMode) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); // reference file should not be written to hdfs, so we set platform here rtplatform = execMode; @@ -146,7 +143,7 @@ public void federatedReshape(Types.ExecMode execMode) { } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRevTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRevTest.java index e4b7ed5e24a..c4f30a574b0 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRevTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRevTest.java @@ -129,14 +129,11 @@ private void runRevTest(ExecMode execMode, boolean activateFedCompilation) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); rtplatform = execMode; if(rtplatform == ExecMode.SPARK) { @@ -176,7 +173,7 @@ private void runRevTest(ExecMode execMode, boolean activateFedCompilation) { } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRightIndexTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRightIndexTest.java index b8fe21ef52d..7a72a868b6e 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRightIndexTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRightIndexTest.java @@ -178,13 +178,10 @@ private void runAggregateOperationTest(IndexType indexType, DataType dataType, E int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); rtplatform = execMode; @@ -236,7 +233,7 @@ private void runAggregateOperationTest(IndexType indexType, DataType dataType, E } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRollTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRollTest.java index f242710338d..aa367ea7d4d 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRollTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRollTest.java @@ -131,14 +131,11 @@ private void runRollTest(ExecMode execMode, boolean activateFedCompilation) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if (!isAlive(t1, t2, t3, t4)) + if (!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); rtplatform = execMode; if (rtplatform == ExecMode.SPARK) { @@ -177,7 +174,7 @@ private void runRollTest(ExecMode execMode, boolean activateFedCompilation) { Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4"))); } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRowIndexTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRowIndexTest.java index 25b02ca1537..37b6f8826e4 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRowIndexTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRowIndexTest.java @@ -108,13 +108,10 @@ private void runRowIndexTest(ExecMode execMode) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); rtplatform = execMode; @@ -155,7 +152,7 @@ private void runRowIndexTest(ExecMode execMode) { } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedSplitTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedSplitTest.java index fc9e4a73ba5..462181e75ea 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedSplitTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedSplitTest.java @@ -106,11 +106,10 @@ public void federatedSplit(Types.ExecMode execMode) { int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2}); try { - if(!isAlive(t1, t2)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); // Run reference dml script with normal matrix @@ -141,7 +140,7 @@ else if(execMode != Types.ExecMode.SPARK) { } finally { - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedStatisticsTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedStatisticsTest.java index ac9706a44ae..dec5fac4432 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedStatisticsTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedStatisticsTest.java @@ -96,10 +96,9 @@ public void federatedLogReg(Types.ExecMode execMode) { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2}); try { - if(!isAlive(t1, t2)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); @@ -134,7 +133,7 @@ public void federatedLogReg(Types.ExecMode execMode) { } finally { - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); resetExecMode(platformOld); } } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTokenizeTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTokenizeTest.java index f8acc4623a7..84b632c5e2c 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTokenizeTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTokenizeTest.java @@ -91,12 +91,10 @@ private void runAggregateOperationTest(ExecMode execMode) { int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3}); try { - if(!isAlive(t1, t2, t3)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); FileFormatPropertiesCSV ffpCSV = new FileFormatPropertiesCSV(false, DataExpression.DEFAULT_DELIM_DELIMITER, @@ -143,7 +141,7 @@ private void runAggregateOperationTest(ExecMode execMode) { } finally { - TestUtils.shutdownThreads(t1, t2, t3); + TestUtils.shutdownThreads(workers); } } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTransferLocalDataTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTransferLocalDataTest.java index 0d29617eb76..2f2adf2bc93 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTransferLocalDataTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTransferLocalDataTest.java @@ -88,13 +88,10 @@ private void runTransferTest(Types.ExecMode execMode) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); rtplatform = execMode; @@ -125,7 +122,7 @@ private void runTransferTest(Types.ExecMode execMode) { } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); resetExecMode(platformOld); } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTriTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTriTest.java index 63432e44658..837fb979fc6 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTriTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTriTest.java @@ -108,12 +108,9 @@ private void runTriTest(ExecMode execMode) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); rtplatform = execMode; @@ -151,7 +148,7 @@ private void runTriTest(ExecMode execMode) { } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedCrossEntropyTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedCrossEntropyTest.java index bd48a45924f..5e0953d0aa5 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedCrossEntropyTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedCrossEntropyTest.java @@ -136,11 +136,10 @@ public void federatedWeightedCrossEntropy(String testname, ExecMode execMode) { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2}); try { - if(!isAlive(t1, t2)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); getAndLoadTestConfiguration(testname); @@ -172,7 +171,7 @@ public void federatedWeightedCrossEntropy(String testname, ExecMode execMode) { } finally { - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); resetExecMode(platform_old); } } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedDivMatrixMultTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedDivMatrixMultTest.java index 6753774f653..3f866be9844 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedDivMatrixMultTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedDivMatrixMultTest.java @@ -297,11 +297,10 @@ public void federatedWeightedDivMatrixMult(String test_name, ExecMode exec_mode) fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2}); try { - if(!isAlive(t1, t2)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); getAndLoadTestConfiguration(test_name); @@ -334,7 +333,7 @@ public void federatedWeightedDivMatrixMult(String test_name, ExecMode exec_mode) Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2"))); } finally { - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); resetExecMode(platform_old); } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedSigmoidTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedSigmoidTest.java index ad21760d086..15a709a2057 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedSigmoidTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedSigmoidTest.java @@ -161,11 +161,10 @@ public void federatedWeightedSigmoid(String test_name, ExecMode exec_mode) { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2}); try { - if(!isAlive(t1, t2)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); getAndLoadTestConfiguration(test_name); @@ -198,7 +197,7 @@ public void federatedWeightedSigmoid(String test_name, ExecMode exec_mode) { } finally { - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); resetExecMode(platform_old); } } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedSquaredLossTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedSquaredLossTest.java index 7fac163c4a6..111c16ded84 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedSquaredLossTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedSquaredLossTest.java @@ -151,11 +151,10 @@ public void federatedWeightedSquaredLoss(String test_name, ExecMode exec_mode) { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2}); try { - if(!isAlive(t1, t2)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); getAndLoadTestConfiguration(test_name); @@ -187,7 +186,7 @@ public void federatedWeightedSquaredLoss(String test_name, ExecMode exec_mode) { } finally { - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); resetExecMode(platform_old); } } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedUnaryMatrixMultTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedUnaryMatrixMultTest.java index 8019edf83f8..610b5ebedf7 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedUnaryMatrixMultTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedUnaryMatrixMultTest.java @@ -160,11 +160,10 @@ public void federatedWeightedUnaryMatrixMult(String test_name, ExecMode exec_mod fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2}); try { - if(!isAlive(t1, t2)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); getAndLoadTestConfiguration(test_name); @@ -197,7 +196,7 @@ public void federatedWeightedUnaryMatrixMult(String test_name, ExecMode exec_mod Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2"))); } finally { - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); resetExecMode(platform_old); } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part4/FederatedRowAggregateTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part4/FederatedRowAggregateTest.java index ba725519a8c..1088de53625 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part4/FederatedRowAggregateTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part4/FederatedRowAggregateTest.java @@ -214,14 +214,11 @@ private void runAggregateOperationTest(OpType type, ExecMode execMode) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); // Run reference dml script with normal matrix @@ -280,7 +277,7 @@ private void runAggregateOperationTest(OpType type, ExecMode execMode) { } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java index 48c9cab632a..0993b41698f 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java @@ -134,13 +134,10 @@ private void runCovarianceTest(ExecMode execMode, boolean alignedFedInput) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4, FED_WORKER_WAIT); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); setExecMode(execMode); @@ -203,7 +200,7 @@ private void runCovarianceTest(ExecMode execMode, boolean alignedFedInput) { } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; } @@ -246,13 +243,10 @@ private void runWeightedCovarianceTest(ExecMode execMode, boolean alignedInput, int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4, FED_WORKER_WAIT); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); loadTestConfiguration(config); @@ -364,7 +358,7 @@ private void runWeightedCovarianceTest(ExecMode execMode, boolean alignedInput, } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCtableTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCtableTest.java index 4ba19fb3d0f..ebb42dfec86 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCtableTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCtableTest.java @@ -133,14 +133,11 @@ public void runCtable(Types.ExecMode execMode, boolean fedOutput, boolean matrix int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); @@ -154,7 +151,7 @@ public void runCtable(Types.ExecMode execMode, boolean fedOutput, boolean matrix } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); resetExecMode(platformOld); } } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFrameMapTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFrameMapTest.java index 037675c584c..189c109fa8b 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFrameMapTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFrameMapTest.java @@ -101,14 +101,11 @@ private void runAggregateOperationTest(ExecMode execMode) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); rtplatform = execMode; @@ -149,7 +146,7 @@ private void runAggregateOperationTest(ExecMode execMode) { } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); } } } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullAggregateTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullAggregateTest.java index d75924746c5..006b5d76883 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullAggregateTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullAggregateTest.java @@ -203,13 +203,10 @@ private void runAggregateOperationTest(OpType type, ExecType instType) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); loadTestConfiguration(config); @@ -264,7 +261,7 @@ private void runAggregateOperationTest(OpType type, ExecType instType) { } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); resetExecMode(platformOld); } } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullCumulativeTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullCumulativeTest.java index 5eb1179efd9..b20d949857e 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullCumulativeTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullCumulativeTest.java @@ -168,14 +168,11 @@ private void runCumOperationTest(OpType type, ExecType instType) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); @@ -245,7 +242,7 @@ private void runCumOperationTest(OpType type, ExecType instType) { } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); resetExecMode(platformOld); } } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedIfelseTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedIfelseTest.java index abf0c8c228f..375a4f95181 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedIfelseTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedIfelseTest.java @@ -138,14 +138,11 @@ private void runTernaryTest(ExecMode execMode, boolean aligned, boolean singleMa int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1); - Process t2 = startLocalFedWorker(port2); - Process t3 = startLocalFedWorker(port3); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); rtplatform = execMode; @@ -220,7 +217,7 @@ private void runTernaryTest(ExecMode execMode, boolean aligned, boolean singleMa } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedMMChainTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedMMChainTest.java index c37341ed681..73cb5470a38 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedMMChainTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedMMChainTest.java @@ -138,14 +138,11 @@ private void runMMChainTest(ExecMode execMode, String TEST_NAME) { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); - Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); - Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); - Process t4 = startLocalFedWorker(port4); + Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}); try { - if(!isAlive(t1, t2, t3, t4)) + if(!isAlive(workers)) throw new RuntimeException("Failed starting federated worker"); rtplatform = execMode; @@ -181,7 +178,7 @@ private void runMMChainTest(ExecMode execMode, String TEST_NAME) { Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4"))); } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java index b70bdd2940b..fd1f88f2761 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java @@ -230,7 +230,7 @@ private void runTransformTest(TransformType type, boolean colnames, boolean line default: throw new RuntimeException("Not supported type"); } - Thread t1 = null, t2 = null, t3 = null, t4 = null; + Thread[] workers = null; try { getAndLoadTestConfiguration(TEST_NAME1); @@ -239,10 +239,7 @@ private void runTransformTest(TransformType type, boolean colnames, boolean line int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); String[] otherargs = lineage ? new String[] {"-lineage", "reuse_full"} : null; - t1 = startLocalFedWorkerThread(port1, otherargs); - t2 = startLocalFedWorkerThread(port2, otherargs); - t3 = startLocalFedWorkerThread(port3, otherargs); - t4 = startLocalFedWorkerThread(port4, otherargs); + workers = startLocalFedWorkerThreads(new int[] {port1, port2, port3, port4}, otherargs); FileFormatPropertiesCSV ffpCSV = new FileFormatPropertiesCSV(true, DataExpression.DEFAULT_DELIM_DELIMITER, DataExpression.DEFAULT_DELIM_FILL, DataExpression.DEFAULT_DELIM_FILL_VALUE, DATASET.equals(DATASET1) ? @@ -345,7 +342,7 @@ private void runTransformTest(TransformType type, boolean colnames, boolean line throw new RuntimeException(ex); } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); resetExecMode(rtold); } } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java index f144e03984a..8c94e62692d 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java @@ -127,7 +127,7 @@ public void runComplexRecodeTestBinarySparseCP() { private void runTransformEncodeDecodeTest(boolean recode, boolean sparse, Types.FileFormat format) { ExecMode rtold = setExecMode(ExecMode.SINGLE_NODE); - Thread t1 = null, t2 = null, t3 = null, t4 = null; + Thread[] workers = null; try { getAndLoadTestConfiguration(TEST_NAME_RECODE); @@ -135,10 +135,7 @@ private void runTransformEncodeDecodeTest(boolean recode, boolean sparse, Types. int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); - t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S); - t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S); - t4 = startLocalFedWorkerThread(port4); + workers = startLocalFedWorkerThreads(new int[] {port1, port2, port3, port4}, null, FED_WORKER_WAIT); // schema Types.ValueType[] schema = new Types.ValueType[cols / 2]; @@ -205,7 +202,7 @@ private void runTransformEncodeDecodeTest(boolean recode, boolean sparse, Types. Assert.fail(ex.getMessage()); } finally { - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); resetExecMode(rtold); } } diff --git a/src/test/java/org/apache/sysds/test/functions/lineage/FedFullReuseTest.java b/src/test/java/org/apache/sysds/test/functions/lineage/FedFullReuseTest.java index ca133caa76e..0c46cd68ea5 100644 --- a/src/test/java/org/apache/sysds/test/functions/lineage/FedFullReuseTest.java +++ b/src/test/java/org/apache/sysds/test/functions/lineage/FedFullReuseTest.java @@ -103,8 +103,7 @@ public void federatedReuse(String test) { int port2 = getRandomAvailablePort(); String[] otherargs = new String[] {"-lineage", "reuse_full"}; Lineage.resetInternalState(); - Thread t1 = startLocalFedWorkerThread(port1, otherargs, FED_WORKER_WAIT_S); - Thread t2 = startLocalFedWorkerThread(port2, otherargs); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}, otherargs, FED_WORKER_WAIT); TestConfiguration config = availableTestConfigurations.get(test); loadTestConfiguration(config); @@ -149,7 +148,7 @@ public void federatedReuse(String test) { } - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); } } diff --git a/src/test/java/org/apache/sysds/test/functions/lineage/FedUDFReuseTest.java b/src/test/java/org/apache/sysds/test/functions/lineage/FedUDFReuseTest.java index f7f01f6f9ef..0cf9d972719 100644 --- a/src/test/java/org/apache/sysds/test/functions/lineage/FedUDFReuseTest.java +++ b/src/test/java/org/apache/sysds/test/functions/lineage/FedUDFReuseTest.java @@ -108,10 +108,7 @@ private void runTriUDFReuse(ExecMode execMode) { int port4 = getRandomAvailablePort(); String[] otherargs = new String[] {"-lineage", "reuse_full"}; Lineage.resetInternalState(); - Thread t1 = startLocalFedWorkerThread(port1, otherargs, FED_WORKER_WAIT_S); - Thread t2 = startLocalFedWorkerThread(port2, otherargs, FED_WORKER_WAIT_S); - Thread t3 = startLocalFedWorkerThread(port3, otherargs, FED_WORKER_WAIT_S); - Thread t4 = startLocalFedWorkerThread(port4, otherargs); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2, port3, port4}, otherargs, FED_WORKER_WAIT); rtplatform = execMode; if(rtplatform == ExecMode.SPARK) { @@ -146,7 +143,7 @@ private void runTriUDFReuse(ExecMode execMode) { // assert reuse count Assert.assertTrue(LineageCacheStatistics.getInstHits() > 0); - TestUtils.shutdownThreads(t1, t2, t3, t4); + TestUtils.shutdownThreads(workers); rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; diff --git a/src/test/java/org/apache/sysds/test/functions/lineage/LineageFedReuseAlg.java b/src/test/java/org/apache/sysds/test/functions/lineage/LineageFedReuseAlg.java index 15f6a7978e5..b8a66196670 100644 --- a/src/test/java/org/apache/sysds/test/functions/lineage/LineageFedReuseAlg.java +++ b/src/test/java/org/apache/sysds/test/functions/lineage/LineageFedReuseAlg.java @@ -93,8 +93,7 @@ public void federatedLmPipeline(ExecMode execMode, boolean contSplits, String TE int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); String[] otherargs = new String[] {"-lineage", "reuse_full"}; - Thread t1 = startLocalFedWorkerThread(port1, otherargs, FED_WORKER_WAIT_S); - Thread t2 = startLocalFedWorkerThread(port2, otherargs); + Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}, otherargs, FED_WORKER_WAIT); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); loadTestConfiguration(config); @@ -136,7 +135,7 @@ public void federatedLmPipeline(ExecMode execMode, boolean contSplits, String TE assertTrue(mmCount > mmCount_reuse); assertTrue(fed_mmCount > fed_mmCount_reuse); - TestUtils.shutdownThreads(t1, t2); + TestUtils.shutdownThreads(workers); } finally { resetExecMode(oldExec);