diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index c7f62b02a2b..46b18cf3476 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -1642,9 +1642,11 @@ protected Process startLocalFedWorker(int port){
/**
* Start a new JVM for a federated worker at the port.
- *
- * @param port Port to use for the JVM
- * @param sleep The sleep time to wait for the worker to start
+ *
+ * @param port Port to use for the JVM
+ * @param sleep Upper bound on the wait for the worker to become ready, in ms; raised to a
+ * minimum value enforced inside {@link FederatedWorkerUtils}. Parameter name retained
+ * for source compatibility with the previous fixed-sleep API.
* @return The process containing the worker
*/
protected Process startLocalFedWorker(int port, int sleep){
@@ -1665,18 +1667,65 @@ protected Process startLocalFedWorker(int port, String[] addArgs) {
/**
* Start new JVM for a federated worker at the port.
- *
- * @param port Port to use for the JVM
- * @param addArgs The arguments to add
- * @param sleep The time to wait for the process to start
+ *
+ *
Returns once the worker's TCP port accepts connections (the worker opens the port after
+ * Netty's bind completes), or throws a {@link RuntimeException} after {@code timeoutMs} elapses.
+ *
+ * @param port Port to use for the JVM
+ * @param addArgs The arguments to add
+ * @param timeoutMs Upper bound on the wait for the worker to become ready, in ms; raised to a
+ * minimum value enforced inside {@link FederatedWorkerUtils}. Parameter name
+ * retained for source compatibility with the previous fixed-sleep API.
* @return the process associated with the worker.
*/
- protected static Process startLocalFedWorker(int port, String[] addArgs, int sleep) {
- Process process = null;
+ protected static Process startLocalFedWorker(int port, String[] addArgs, int timeoutMs) {
+ Process process = spawnLocalFedWorker(port, addArgs);
+ FederatedWorkerUtils.waitForWorker(process, port, timeoutMs);
+ return process;
+ }
+
+ /**
+ * Start N federated worker JVMs back to back, then wait for all of them to become ready in one
+ * shared poll loop. The wall-clock wait scales with the slowest worker rather than the sum of the
+ * per-worker waits.
+ *
+ * @param ports Ports to use, one per worker
+ * @return The process per port, in the same order as {@code ports}.
+ */
+ protected static Process[] startLocalFedWorkers(int[] ports) {
+ return startLocalFedWorkers(ports, null, FED_WORKER_WAIT);
+ }
+
+ /** @see #startLocalFedWorkers(int[], String[], int) */
+ protected static Process[] startLocalFedWorkers(int[] ports, String[] addArgs) {
+ return startLocalFedWorkers(ports, addArgs, FED_WORKER_WAIT);
+ }
+
+ /**
+ * Start N federated worker JVMs back to back, then wait for all of them to become ready in one
+ * shared poll loop.
+ *
+ * @param ports Ports to use, one per worker
+ * @param addArgs Extra worker CLI args (applied to every worker), or null
+ * @param timeoutMs Upper bound on the wait, in ms; raised to a minimum value enforced inside
+ * {@link FederatedWorkerUtils}.
+ * @return The process per port, in the same order as {@code ports}.
+ */
+ protected static Process[] startLocalFedWorkers(int[] ports, String[] addArgs, int timeoutMs) {
+ Process[] processes = new Process[ports.length];
+ for(int i = 0; i < ports.length; i++) {
+ processes[i] = spawnLocalFedWorker(ports[i], addArgs);
+ }
+ FederatedWorkerUtils.waitForWorkers(processes, ports, timeoutMs);
+ return processes;
+ }
+
+ /** Spawn a federated worker JVM and return without waiting for the port to bind. */
+ private static Process spawnLocalFedWorker(int port, String[] addArgs) {
String separator = System.getProperty("file.separator");
String classpath = System.getProperty("java.class.path");
String path = System.getProperty("java.home") + separator + "bin" + separator + "java";
- String[] args = new String[] {path, "-Xmx1000m", "-Xms1000m", "-Xmn100m",
+ String[] args = new String[] {path, "-Xmx1000m", "-Xms1000m", "-Xmn100m",
"--add-opens=java.base/java.nio=ALL-UNNAMED" ,
"--add-opens=java.base/java.io=ALL-UNNAMED" ,
"--add-opens=java.base/java.util=ALL-UNNAMED" ,
@@ -1701,19 +1750,14 @@ protected static Process startLocalFedWorker(int port, String[] addArgs, int sle
DMLScript.class.getName(), "-w", Integer.toString(port), "-stats"});
if(addArgs != null)
args = ArrayUtils.addAll(args, addArgs);
-
- ProcessBuilder processBuilder = new ProcessBuilder(args).inheritIO();
+ ProcessBuilder processBuilder = new ProcessBuilder(args).inheritIO();
try {
- process = processBuilder.start();
- // Give some time to startup the worker.
- sleep(sleep);
+ return processBuilder.start();
}
- catch(IOException | InterruptedException e) {
- e.printStackTrace();
+ catch(IOException e) {
+ throw new RuntimeException("Failed to launch federated worker process on port " + port, e);
}
- isAlive(process);
- return process;
}
/**
@@ -1743,7 +1787,7 @@ protected Process startLocalFedMonitoring(int port, String[] addArgs) {
}
/**
- * Start a thread for a worker. This will share the same JVM, so all static variables will be shared.!
+ * Start a thread for a worker. This will share the same JVM, so all static variables will be shared.
*
* Also when using the local Fed Worker thread the statistics printing, and clearing from the worker is disabled.
*
@@ -1769,12 +1813,14 @@ public static Thread startLocalFedWorkerThread(int port, String[] otherArgs) {
}
/**
- * Start a thread for a worker. This will share the same JVM, so all static variables will be shared.!
+ * Start a thread for a worker. This will share the same JVM, so all static variables will be shared.
*
* Also when using the local Fed Worker thread the statistics printing, and clearing from the worker is disabled.
*
* @param port Port to use
- * @param sleep The amount of time to wait for the worker startup. in Milliseconds
+ * @param sleep Upper bound on the wait for the worker to become ready, in ms; raised to a minimum
+ * value enforced inside {@link FederatedWorkerUtils}. Parameter name retained for source
+ * compatibility with the previous fixed-sleep API.
* @return The thread associated with the worker.
*/
public static Thread startLocalFedWorkerThread(int port, int sleep) {
@@ -1782,50 +1828,99 @@ public static Thread startLocalFedWorkerThread(int port, int sleep) {
}
/**
- * Start a thread for a worker. This will share the same JVM, so all static variables will be shared.!
- *
- * Also when using the local Fed Worker thread the statistics printing, and clearing from the worker is disabled.
- *
+ * Start a thread for a worker. This will share the same JVM, so all static variables will be shared.
+ *
+ *
Also when using the local Fed Worker thread the statistics printing, and clearing from the worker is
+ * disabled.
+ *
+ *
Returns once the worker's TCP port accepts connections (the worker opens the port after Netty's bind
+ * completes), or throws a {@link RuntimeException} after {@code timeoutMs} elapses.
+ *
* @param port Port to use
* @param otherArgs The command line arguments to start the worker with
- * @param sleep The amount of time to wait for the worker startup. in Milliseconds
+ * @param timeoutMs Upper bound on the wait for the worker to become ready, in ms; raised to a
+ * minimum value enforced inside {@link FederatedWorkerUtils}. Parameter name
+ * retained for source compatibility with the previous fixed-sleep API.
* @return The thread associated with the worker.
*/
- public static Thread startLocalFedWorkerThread(int port, String[] otherArgs, int sleep) {
+ public static Thread startLocalFedWorkerThread(int port, String[] otherArgs, int timeoutMs) {
+ Thread t = spawnLocalFedWorkerThread(port, otherArgs);
+ FederatedWorkerUtils.waitForWorker(t, port, timeoutMs);
+ return t;
+ }
+
+ /**
+ * Start N federated worker threads in the same JVM back to back, then wait for all of them to
+ * become ready in one shared poll loop. The wall-clock wait scales with the slowest worker rather
+ * than the sum of the per-worker waits.
+ *
+ * @param ports Ports to use, one per worker
+ * @return The thread per port, in the same order as {@code ports}.
+ */
+ public static Thread[] startLocalFedWorkerThreads(int[] ports) {
+ return startLocalFedWorkerThreads(ports, null, FED_WORKER_WAIT);
+ }
+ /** @see #startLocalFedWorkerThreads(int[], String[], int) */
+ public static Thread[] startLocalFedWorkerThreads(int[] ports, String[] otherArgs) {
+ return startLocalFedWorkerThreads(ports, otherArgs, FED_WORKER_WAIT);
+ }
+
+ /**
+ * Start N federated worker threads in the same JVM back to back, then wait for all of them to
+ * become ready in one shared poll loop.
+ *
+ * @param ports Ports to use, one per worker
+ * @param otherArgs Extra worker CLI args (applied to every worker), or null
+ * @param timeoutMs Upper bound on the wait, in ms; raised to a minimum value enforced inside
+ * {@link FederatedWorkerUtils}.
+ * @return The thread per port, in the same order as {@code ports}.
+ */
+ public static Thread[] startLocalFedWorkerThreads(int[] ports, String[] otherArgs, int timeoutMs) {
+ Thread[] threads = new Thread[ports.length];
+ for(int i = 0; i < ports.length; i++) {
+ threads[i] = spawnLocalFedWorkerThread(ports[i], otherArgs);
+ // Sleep THREAD_SPAWN_STAGGER_MS between in-JVM thread spawns to reduce contention on
+ // shared static initialization in DMLScript / FederatedWorker (e.g. LineageCacheConfig
+ // setters) when multiple worker threads enter main() concurrently.
+ if(i + 1 < ports.length) {
+ try {
+ java.util.concurrent.TimeUnit.MILLISECONDS.sleep(THREAD_SPAWN_STAGGER_MS);
+ }
+ catch(InterruptedException e) {
+ Thread.currentThread().interrupt();
+ throw new RuntimeException("Interrupted while spawning federated worker threads", e);
+ }
+ }
+ }
+ FederatedWorkerUtils.waitForWorkers(threads, ports, timeoutMs);
+ return threads;
+ }
+
+ private static final int THREAD_SPAWN_STAGGER_MS = 25;
+
+ /** Spawn a federated worker thread in this JVM and return without waiting for the port to bind. */
+ private static Thread spawnLocalFedWorkerThread(int port, String[] otherArgs) {
ArrayList args = new ArrayList<>();
-
args.add("-w");
args.add(Integer.toString(port));
-
if(otherArgs != null)
- for( String s : otherArgs)
+ for(String s : otherArgs)
args.add(s);
String[] finalArguments = args.toArray(new String[args.size()]);
Statistics.allowWorkerStatistics = false;
- try {
- Thread t = new Thread(() -> {
- try {
- main(finalArguments);
- }
- catch(Exception e) {
- LOG.error("Exception in startup of federated worker", e);
- }
- });
- t.start();
- java.util.concurrent.TimeUnit.MILLISECONDS.sleep(sleep);
- if(!t.isAlive())
- throw new RuntimeException("Failed starting federated worker");
- return t;
- }
- catch(InterruptedException e) {
- e.printStackTrace();
- fail("Failed to start federated worker : " + e);
- // should never happen
- return null;
- }
+ Thread t = new Thread(() -> {
+ try {
+ main(finalArguments);
+ }
+ catch(Exception e) {
+ LOG.error("Exception in startup of federated worker", e);
+ }
+ });
+ t.start();
+ return t;
}
public static boolean isAlive(Thread... threads){
@@ -1846,28 +1941,44 @@ public static boolean isAlive(Process... processes) {
/**
* Start java worker in same JVM.
- *
+ *
+ * Returns once the worker's TCP port accepts connections (the worker opens the port after
+ * Netty's bind completes), or throws a {@link RuntimeException} after the default federated worker
+ * timeout elapses. The port is extracted from {@code args}, which must contain {@code "-w" }.
+ *
* @param args the command line arguments
- * @return the thread associated with the process.s
+ * @return the thread associated with the worker.
*/
public static Thread startLocalFedWorkerWithArgs(String[] args) {
- Thread t = null;
+ final int port = extractWorkerPort(args);
+ Thread t = new Thread(() -> {
+ try {
+ main(args);
+ }
+ catch(IOException e) {
+ // IOException from main() is not propagated here; startup failures surface as a
+ // readiness timeout from waitForWorker below.
+ }
+ });
+ t.start();
+ FederatedWorkerUtils.waitForWorker(t, port, FED_WORKER_WAIT);
+ return t;
+ }
- try {
- t = new Thread(() -> {
+ private static int extractWorkerPort(String[] args) {
+ for(int i = 0; i < args.length - 1; i++) {
+ if("-w".equals(args[i])) {
try {
- main(args);
+ return Integer.parseInt(args[i + 1]);
}
- catch(IOException e) {
+ catch(NumberFormatException e) {
+ throw new IllegalArgumentException(
+ "Federated worker args contain non-numeric port after -w: " + args[i + 1], e);
}
- });
- t.start();
- java.util.concurrent.TimeUnit.MILLISECONDS.sleep(FED_WORKER_WAIT);
- }
- catch(InterruptedException e) {
- // Should happen at closing of the worker so don't print
+ }
}
- return t;
+ throw new IllegalArgumentException("Federated worker args must contain '-w ': "
+ + Arrays.toString(args));
}
private boolean rCompareException(boolean exceptionExpected, String errMessage, Throwable e, boolean result) {
diff --git a/src/test/java/org/apache/sysds/test/FederatedWorkerUtils.java b/src/test/java/org/apache/sysds/test/FederatedWorkerUtils.java
new file mode 100644
index 00000000000..bd12f7c3bdc
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/FederatedWorkerUtils.java
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test;
+
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.net.Socket;
+import java.util.function.BooleanSupplier;
+
+/**
+ * Test helpers that block until a federated worker is accepting TCP connections on its port.
+ *
+ * The federated worker opens its TCP port after Netty's {@code bind().sync()} returns; a successful
+ * TCP connect to that port therefore indicates that the worker is ready to accept requests. The methods
+ * here poll for that signal and throw {@link RuntimeException} on timeout or if the underlying
+ * {@code Process}/{@code Thread} exits before the port becomes ready.
+ */
+public final class FederatedWorkerUtils {
+
+ /** Sleep between successive poll rounds, in milliseconds. */
+ private static final int POLL_INTERVAL_MS = 25;
+
+ /** Per-attempt {@link Socket#connect} timeout, in milliseconds. */
+ private static final int CONNECT_TIMEOUT_MS = 25;
+
+ /**
+ * Minimum value applied to the caller-supplied {@code timeoutMs}. The wait returns as soon as the
+ * worker accepts a connection, so this only affects the upper bound used when a worker never becomes
+ * ready. Set to 3s so a worker that fails to bind is reported quickly while still leaving room for
+ * Netty bind under typical local/CI conditions.
+ */
+ private static final int MIN_TIMEOUT_MS = 3_000;
+
+ private FederatedWorkerUtils() {
+ // utility class
+ }
+
+ /**
+ * Block until a federated worker is accepting TCP connections on {@code port}, or throw a
+ * {@link RuntimeException} after the effective timeout elapses.
+ *
+ * @param port port the federated worker is expected to bind
+ * @param timeoutMs upper bound on the wait, in ms; raised to {@link #MIN_TIMEOUT_MS} if smaller
+ */
+ public static void waitForWorker(int port, int timeoutMs) {
+ waitForWorker(port, timeoutMs, () -> true, "worker");
+ }
+
+ /**
+ * Block until a federated worker is accepting TCP connections on {@code port}. Returns early with
+ * a {@link RuntimeException} if {@code aliveCheck} reports the worker is no longer alive.
+ */
+ public static void waitForWorker(int port, int timeoutMs, BooleanSupplier aliveCheck, String workerKind) {
+ final int effectiveTimeout = Math.max(timeoutMs, MIN_TIMEOUT_MS);
+ final long deadline = System.currentTimeMillis() + effectiveTimeout;
+ while(System.currentTimeMillis() < deadline) {
+ if(!aliveCheck.getAsBoolean()) {
+ throw new RuntimeException(
+ "Federated " + workerKind + " on port " + port + " died before becoming ready.");
+ }
+ if(tryConnect(port)) {
+ return;
+ }
+ sleepQuietly();
+ }
+ throw new RuntimeException("Federated " + workerKind + " on port " + port
+ + " did not become ready within " + effectiveTimeout + "ms.");
+ }
+
+ /** Overload that also returns early if the given worker process exits before the port is ready. */
+ public static void waitForWorker(Process process, int port, int timeoutMs) {
+ waitForWorker(port, timeoutMs, process::isAlive, "worker process");
+ }
+
+ /** Overload that also returns early if the given worker thread exits before the port is ready. */
+ public static void waitForWorker(Thread thread, int port, int timeoutMs) {
+ waitForWorker(port, timeoutMs, thread::isAlive, "worker thread");
+ }
+
+ /**
+ * Block until every listed federated worker is accepting TCP connections. All ports are polled in
+ * one shared loop, so the wall-clock wait is bounded by the slowest worker rather than the sum of
+ * individual waits.
+ *
+ * @param ports ports the workers are expected to bind
+ * @param timeoutMs upper bound on the wait, in ms; raised to {@link #MIN_TIMEOUT_MS} if smaller
+ */
+ public static void waitForWorkers(int[] ports, int timeoutMs) {
+ waitForWorkers(ports, timeoutMs, i -> true, "workers");
+ }
+
+ /**
+ * Overload that also returns early if any of the worker processes exits before its port is ready.
+ *
+ * @throws IllegalArgumentException if {@code processes.length != ports.length}
+ */
+ public static void waitForWorkers(Process[] processes, int[] ports, int timeoutMs) {
+ if(processes.length != ports.length) {
+ throw new IllegalArgumentException(
+ "processes/ports length mismatch: " + processes.length + " vs " + ports.length);
+ }
+ waitForWorkers(ports, timeoutMs, i -> processes[i].isAlive(), "worker processes");
+ }
+
+ /**
+ * Overload that also returns early if any of the worker threads exits before its port is ready.
+ *
+ * @throws IllegalArgumentException if {@code threads.length != ports.length}
+ */
+ public static void waitForWorkers(Thread[] threads, int[] ports, int timeoutMs) {
+ if(threads.length != ports.length) {
+ throw new IllegalArgumentException(
+ "threads/ports length mismatch: " + threads.length + " vs " + ports.length);
+ }
+ waitForWorkers(ports, timeoutMs, i -> threads[i].isAlive(), "worker threads");
+ }
+
+ /**
+ * Bulk variant taking a per-index liveness predicate so callers can plug in either {@code Process}
+ * or {@code Thread} liveness. Each port flips to ready as soon as it accepts a connection; the loop
+ * sleeps once per round (after probing all pending ports) rather than between individual port
+ * probes.
+ */
+ public static void waitForWorkers(int[] ports, int timeoutMs, java.util.function.IntPredicate aliveCheck,
+ String workerKind) {
+ final int effectiveTimeout = Math.max(timeoutMs, MIN_TIMEOUT_MS);
+ final long deadline = System.currentTimeMillis() + effectiveTimeout;
+ final boolean[] ready = new boolean[ports.length];
+ int remaining = ports.length;
+ while(remaining > 0 && System.currentTimeMillis() < deadline) {
+ boolean madeProgress = false;
+ for(int i = 0; i < ports.length; i++) {
+ if(ready[i]) {
+ continue;
+ }
+ if(!aliveCheck.test(i)) {
+ throw new RuntimeException("Federated " + workerKind + " on port " + ports[i]
+ + " died before becoming ready.");
+ }
+ if(tryConnect(ports[i])) {
+ ready[i] = true;
+ remaining--;
+ madeProgress = true;
+ }
+ }
+ if(remaining > 0 && !madeProgress) {
+ sleepQuietly();
+ }
+ }
+ if(remaining > 0) {
+ StringBuilder sb = new StringBuilder("Federated ").append(workerKind)
+ .append(" did not all become ready within ").append(effectiveTimeout).append("ms. Pending ports:");
+ for(int i = 0; i < ports.length; i++) {
+ if(!ready[i]) {
+ sb.append(' ').append(ports[i]);
+ }
+ }
+ throw new RuntimeException(sb.toString());
+ }
+ }
+
+ private static boolean tryConnect(int port) {
+ try(Socket s = new Socket()) {
+ s.connect(new InetSocketAddress("localhost", port), CONNECT_TIMEOUT_MS);
+ return true;
+ }
+ catch(IOException e) {
+ return false;
+ }
+ }
+
+ private static void sleepQuietly() {
+ try {
+ Thread.sleep(POLL_INTERVAL_MS);
+ }
+ catch(InterruptedException ie) {
+ Thread.currentThread().interrupt();
+ throw new RuntimeException("Interrupted while waiting for federated worker", ie);
+ }
+ }
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedAlsCGTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedAlsCGTest.java
index 5e880166afe..3b53c71c0ea 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedAlsCGTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedAlsCGTest.java
@@ -120,8 +120,7 @@ public void federatedAlsCG(String testname, ExecMode execMode)
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread thread1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread thread2 = startLocalFedWorkerThread(port2);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2});
getAndLoadTestConfiguration(testname);
@@ -153,7 +152,7 @@ public void federatedAlsCG(String testname, ExecMode execMode)
HashMap fedResults = readDMLMatrixFromOutputDir(OUTPUT_NAME);
TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, "Fed", "Ref");
- TestUtils.shutdownThreads(thread1, thread2);
+ TestUtils.shutdownThreads(workers);
// check for federated operations
Assert.assertTrue(heavyHittersContainsString("fed_!="));
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedBivarTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedBivarTest.java
index f2f35700119..0537ecf18de 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedBivarTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedBivarTest.java
@@ -114,10 +114,7 @@ public void federatedL2SVM(Types.ExecMode execMode) {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
- Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
- Thread t4 = startLocalFedWorkerThread(port4);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2, port3, port4});
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
@@ -139,7 +136,7 @@ public void federatedL2SVM(Types.ExecMode execMode) {
// compare via files
compareResults(1e-9);
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
// check that federated input files are still existing
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedCorTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedCorTest.java
index 73bf8e91de6..2c81da6333e 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedCorTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedCorTest.java
@@ -102,10 +102,7 @@ private void runAggregateOperationTest(ExecMode execMode) {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
- Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
- Thread t4 = startLocalFedWorkerThread(port4);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2, port3, port4});
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
@@ -140,7 +137,7 @@ private void runAggregateOperationTest(ExecMode execMode) {
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java
index 6d8e8165305..b54cb3d0a64 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java
@@ -100,8 +100,7 @@ public void federatedGLM(Types.ExecMode execMode) {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2});
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
@@ -122,7 +121,7 @@ public void federatedGLM(Types.ExecMode execMode) {
// compare via files
compareResults(1e-2);
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
// check for federated operations
Assert.assertTrue(heavyHittersContainsString("fed_ba+*"));
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java
index c8605ac3d88..c1fc83ae3c3 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java
@@ -113,8 +113,7 @@ public void federatedKmeans(Types.ExecMode execMode, boolean singleWorker) {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2});
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
@@ -155,7 +154,7 @@ public void federatedKmeans(Types.ExecMode execMode, boolean singleWorker) {
// compare via files
// compareResults(1e-9); --> randomized
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
resetExecMode(platformOld);
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedL2SVMTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedL2SVMTest.java
index f7040e6a4dc..0b38317f73b 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedL2SVMTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedL2SVMTest.java
@@ -104,8 +104,7 @@ public void federatedL2SVM(Types.ExecMode execMode, boolean singleWorker) {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2});
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
@@ -126,7 +125,7 @@ public void federatedL2SVM(Types.ExecMode execMode, boolean singleWorker) {
// compare via files
compareResults(1e-9);
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLmPipeline.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLmPipeline.java
index 8b02246e0aa..f3e921dd935 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLmPipeline.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLmPipeline.java
@@ -109,10 +109,7 @@ public void federatedLmPipeline(ExecMode execMode, boolean contSplits, String TE
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
- Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
- Thread t4 = startLocalFedWorkerThread(port4);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2, port3, port4});
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
@@ -134,7 +131,7 @@ public void federatedLmPipeline(ExecMode execMode, boolean contSplits, String TE
// compare via files
compareResults(1e-2);
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
// check correct federated operations
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java
index a3e91ef37d7..74882637d95 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java
@@ -95,8 +95,7 @@ public void federatedLogReg(Types.ExecMode execMode) {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2});
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
@@ -117,7 +116,7 @@ public void federatedLogReg(Types.ExecMode execMode) {
// compare via files
compareResults(1e-9);
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
// check for federated operations
Assert.assertTrue("contains fed_ba+*", heavyHittersContainsString("fed_ba+*"));
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedMSVMTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedMSVMTest.java
index c5344f9d847..97a4d84a0b9 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedMSVMTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedMSVMTest.java
@@ -97,8 +97,7 @@ public void federatedMSVM(Types.ExecMode execMode, boolean singleWorker) {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2});
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
@@ -119,7 +118,7 @@ public void federatedMSVM(Types.ExecMode execMode, boolean singleWorker) {
// compare via files
compareResults(1e-9);
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPCATest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPCATest.java
index 2a4186680c5..9ddd59483ce 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPCATest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPCATest.java
@@ -102,10 +102,7 @@ public void federatedPCA(Types.ExecMode execMode) {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
- Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
- Thread t4 = startLocalFedWorkerThread(port4);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2, port3, port4});
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
@@ -127,7 +124,7 @@ public void federatedPCA(Types.ExecMode execMode) {
// compare via files
compareResults(1e-9);
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
// check for federated operations
Assert.assertTrue(heavyHittersContainsString("fed_ba+*"));
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPNMFTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPNMFTest.java
index 19fb72ce47b..183f9900a6d 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPNMFTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPNMFTest.java
@@ -113,8 +113,7 @@ public void federatedPNMF(ExecMode execMode)
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread thread1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread thread2 = startLocalFedWorkerThread(port2);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2});
getAndLoadTestConfiguration(TEST_NAME);
@@ -141,7 +140,7 @@ public void federatedPNMF(ExecMode execMode)
HashMap fedResults = readDMLMatrixFromOutputDir(OUTPUT_NAME);
TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, "Fed", "Ref");
- TestUtils.shutdownThreads(thread1, thread2);
+ TestUtils.shutdownThreads(workers);
// check for federated operations
Assert.assertTrue(heavyHittersContainsString("fed_wcemm"));
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedUnivarTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedUnivarTest.java
index a4a8236e3a8..21f2e392250 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedUnivarTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedUnivarTest.java
@@ -100,10 +100,7 @@ public void federatedL2SVM(Types.ExecMode execMode) {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
- Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
- Thread t4 = startLocalFedWorkerThread(port4);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2, port3, port4});
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
@@ -126,7 +123,7 @@ public void federatedL2SVM(Types.ExecMode execMode) {
// compare via files
compareResults(1e-9);
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
// check for federated operations
Assert.assertTrue(heavyHittersContainsString("fed_uacmax"));
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedVarTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedVarTest.java
index db071220d00..445abfa1485 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedVarTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedVarTest.java
@@ -115,10 +115,7 @@ private void runAggregateOperationTest(ExecMode execMode, boolean federatedCompi
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
- Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
- Thread t4 = startLocalFedWorkerThread(port4);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2, port3, port4});
rtplatform = execMode;
if(rtplatform == ExecMode.SPARK)
@@ -157,7 +154,7 @@ private void runAggregateOperationTest(ExecMode execMode, boolean federatedCompi
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
OptimizerUtils.FEDERATED_COMPILATION = false;
rtplatform = platformOld;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedYL2SVMTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedYL2SVMTest.java
index b8eef26f35a..a87328a8aa4 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedYL2SVMTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedYL2SVMTest.java
@@ -117,8 +117,7 @@ public void federatedL2SVM(Types.ExecMode execMode, String testName) {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2});
TestConfiguration config = availableTestConfigurations.get(testName);
loadTestConfiguration(config);
@@ -140,7 +139,7 @@ public void federatedL2SVM(Types.ExecMode execMode, String testName) {
// compare via files
compareResults(1e-9);
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCellwiseTmplTest.java b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCellwiseTmplTest.java
index 2f6e5e46173..94fab208e09 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCellwiseTmplTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCellwiseTmplTest.java
@@ -158,8 +158,7 @@ private void testFederatedCodegen(ExecMode exec_mode) {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread thread1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread thread2 = startLocalFedWorkerThread(port2);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2});
getAndLoadTestConfiguration(TEST_NAME);
@@ -188,7 +187,7 @@ private void testFederatedCodegen(ExecMode exec_mode) {
HashMap fedResults = readDMLMatrixFromOutputDir(OUTPUT_NAME);
TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, "Fed", "Ref");
- TestUtils.shutdownThreads(thread1, thread2);
+ TestUtils.shutdownThreads(workers);
// check for federated operations
Assert.assertTrue(heavyHittersContainsSubString("fed_spoofCell"));
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCodegenMultipleFedMOTest.java b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCodegenMultipleFedMOTest.java
index 61722dbc467..a23219e1587 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCodegenMultipleFedMOTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCodegenMultipleFedMOTest.java
@@ -194,8 +194,7 @@ private void testFederatedCodegenMultipleFedMO(ExecMode exec_mode) {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread thread1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread thread2 = startLocalFedWorkerThread(port2);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2});
getAndLoadTestConfiguration(TEST_NAME);
@@ -228,7 +227,7 @@ private void testFederatedCodegenMultipleFedMO(ExecMode exec_mode) {
HashMap fedResults = readDMLMatrixFromOutputDir(OUTPUT_NAME);
TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, "Fed", "Ref");
- TestUtils.shutdownThreads(thread1, thread2);
+ TestUtils.shutdownThreads(workers);
// check for federated operations
if(test_num >= 0 && test_num < 100)
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedMultiAggTmplTest.java b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedMultiAggTmplTest.java
index 33a551d28d0..74cd2c0a420 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedMultiAggTmplTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedMultiAggTmplTest.java
@@ -143,8 +143,7 @@ private void testFederatedCodegenMultiAgg(ExecMode exec_mode) {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread thread1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread thread2 = startLocalFedWorkerThread(port2);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2});
getAndLoadTestConfiguration(TEST_NAME);
@@ -173,7 +172,7 @@ private void testFederatedCodegenMultiAgg(ExecMode exec_mode) {
HashMap fedResults = readDMLMatrixFromOutputDir(OUTPUT_NAME);
TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, "Fed", "Ref");
- TestUtils.shutdownThreads(thread1, thread2);
+ TestUtils.shutdownThreads(workers);
// check for federated operations
Assert.assertTrue(heavyHittersContainsSubString("fed_spoofMA"));
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedOuterProductTmplTest.java b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedOuterProductTmplTest.java
index cef5fd5e99f..dd08ae58636 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedOuterProductTmplTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedOuterProductTmplTest.java
@@ -146,8 +146,7 @@ private void testFederatedCodegenOuterProduct(ExecMode exec_mode) {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread thread1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread thread2 = startLocalFedWorkerThread(port2);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2});
getAndLoadTestConfiguration(TEST_NAME);
@@ -176,7 +175,7 @@ private void testFederatedCodegenOuterProduct(ExecMode exec_mode) {
HashMap fedResults = readDMLMatrixFromOutputDir(OUTPUT_NAME);
TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, "Fed", "Ref");
- TestUtils.shutdownThreads(thread1, thread2);
+ TestUtils.shutdownThreads(workers);
// check for federated operations
Assert.assertTrue(heavyHittersContainsSubString("fed_spoofOP"));
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedRowwiseTmplTest.java b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedRowwiseTmplTest.java
index b47a718c0ca..185c15557d1 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedRowwiseTmplTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedRowwiseTmplTest.java
@@ -147,8 +147,7 @@ private void testFederatedCodegenRowwise(ExecMode exec_mode) {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread thread1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread thread2 = startLocalFedWorkerThread(port2);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2});
getAndLoadTestConfiguration(TEST_NAME);
@@ -178,7 +177,7 @@ private void testFederatedCodegenRowwise(ExecMode exec_mode) {
TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, "Fed", "Ref");
- TestUtils.shutdownThreads(thread1, thread2);
+ TestUtils.shutdownThreads(workers);
// check for federated operations
if(row_partitioned)
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedDynamicPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedDynamicPlanningTest.java
index 53dce3f01c7..bcc2b61adf2 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedDynamicPlanningTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedDynamicPlanningTest.java
@@ -118,7 +118,7 @@ private void loadAndRunTest(String[] expectedHeavyHitters, String testName) {
Types.ExecMode platformOld = rtplatform;
rtplatform = Types.ExecMode.SINGLE_NODE;
- Thread t1 = null, t2 = null;
+ Thread[] workers = null;
try {
getAndLoadTestConfiguration(testName);
@@ -128,8 +128,7 @@ private void loadAndRunTest(String[] expectedHeavyHitters, String testName) {
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- t2 = startLocalFedWorkerThread(port2);
+ workers = startLocalFedWorkerThreads(new int[] {port1, port2}, null, FED_WORKER_WAIT);
// Run actual dml script with federated matrix
fullDMLScriptName = HOME + testName + ".dml";
@@ -164,7 +163,7 @@ private void loadAndRunTest(String[] expectedHeavyHitters, String testName) {
+ Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
}
finally {
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedKMeansPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedKMeansPlanningTest.java
index 9a9ff18d28b..b6c239854b3 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedKMeansPlanningTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedKMeansPlanningTest.java
@@ -125,7 +125,7 @@ private void loadAndRunTest(String[] expectedHeavyHitters, String testName, Stri
Types.ExecMode platformOld = rtplatform;
rtplatform = Types.ExecMode.SINGLE_NODE;
- Thread t1 = null, t2 = null;
+ Thread[] workers = null;
try {
getAndLoadTestConfiguration(testName);
@@ -135,8 +135,7 @@ private void loadAndRunTest(String[] expectedHeavyHitters, String testName, Stri
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- t2 = startLocalFedWorkerThread(port2);
+ workers = startLocalFedWorkerThreads(new int[] {port1, port2}, null, FED_WORKER_WAIT);
// Run actual dml script with federated matrix
fullDMLScriptName = HOME + testName + ".dml";
@@ -158,7 +157,7 @@ private void loadAndRunTest(String[] expectedHeavyHitters, String testName, Stri
// fail("The following expected heavy hitters are missing: "
// + Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
} finally {
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedL2SVMPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedL2SVMPlanningTest.java
index 91143396208..39ab4901480 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedL2SVMPlanningTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedL2SVMPlanningTest.java
@@ -170,7 +170,7 @@ private void loadAndRunTest(String[] expectedHeavyHitters, String testName, Stri
Types.ExecMode platformOld = rtplatform;
rtplatform = Types.ExecMode.SINGLE_NODE;
- Thread t1 = null, t2 = null;
+ Thread[] workers = null;
try {
getAndLoadTestConfiguration(testName);
@@ -180,8 +180,7 @@ private void loadAndRunTest(String[] expectedHeavyHitters, String testName, Stri
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- t2 = startLocalFedWorkerThread(port2);
+ workers = startLocalFedWorkerThreads(new int[] {port1, port2}, null, FED_WORKER_WAIT);
// Run actual dml script with federated matrix
fullDMLScriptName = HOME + testName + ".dml";
@@ -205,7 +204,7 @@ private void loadAndRunTest(String[] expectedHeavyHitters, String testName, Stri
// + Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
}
finally {
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedMultiplyPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedMultiplyPlanningTest.java
index 5b54f14d059..0a2bd230da4 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedMultiplyPlanningTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedMultiplyPlanningTest.java
@@ -240,7 +240,7 @@ private void federatedTwoMatricesTest(Types.ExecMode execMode, String testName,
if(rtplatform == Types.ExecMode.SPARK) {
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
}
- Thread t1 = null, t2 = null;
+ Thread[] workers = null;
try{
getAndLoadTestConfiguration(testName);
@@ -250,8 +250,7 @@ private void federatedTwoMatricesTest(Types.ExecMode execMode, String testName,
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- t2 = startLocalFedWorkerThread(port2);
+ workers = startLocalFedWorkerThreads(new int[] {port1, port2}, null, FED_WORKER_WAIT);
// Run actual dml script with federated matrix
fullDMLScriptName = HOME + testName + ".dml";
@@ -275,7 +274,7 @@ private void federatedTwoMatricesTest(Types.ExecMode execMode, String testName,
fail("The following expected heavy hitters are missing: "
+ Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
} finally {
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java
index ff96ad8af2f..566d3ff323c 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java
@@ -91,8 +91,7 @@ public void federatedRead(Types.ExecMode execMode, int workerCount) {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}, null, FED_WORKER_WAIT);
String host = "localhost";
try {
@@ -137,6 +136,6 @@ public void federatedRead(Types.ExecMode execMode, int workerCount) {
resetExecMode(oldPlatform);
}
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSSLTest.java b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSSLTest.java
index 5d4d2c4e9c3..5f5c09e07cb 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSSLTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSSLTest.java
@@ -96,8 +96,7 @@ public void federatedRead(Types.ExecMode execMode) {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}, null, FED_WORKER_WAIT);
String host = "localhost";
@@ -134,7 +133,7 @@ public void federatedRead(Types.ExecMode execMode) {
resetExecMode(oldPlatform);
}
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
}
/**
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSparsityPropagationTest.java b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSparsityPropagationTest.java
index 42165d6b68f..f2654b9bb66 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSparsityPropagationTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSparsityPropagationTest.java
@@ -102,8 +102,7 @@ public void federatedGet(ExecMode execMode, double sparsity) {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}, null, FED_WORKER_WAIT);
getAndLoadTestConfiguration(TEST_NAME);
@@ -136,7 +135,7 @@ public void federatedGet(ExecMode execMode, double sparsity) {
compareNNZ(refNNZ, fedNNZ);
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
resetExecMode(platform_old);
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedWriterTest.java b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedWriterTest.java
index d8bb7431472..68fb3a7da63 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedWriterTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedWriterTest.java
@@ -83,8 +83,7 @@ public void federatedWrite(ExecMode execMode) {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}, null, FED_WORKER_WAIT);
try {
@@ -122,6 +121,6 @@ public void federatedWrite(ExecMode execMode) {
resetExecMode(oldPlatform);
}
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/multitenant/MultiTenantTestBase.java b/src/test/java/org/apache/sysds/test/functions/federated/multitenant/MultiTenantTestBase.java
index c3a4756a2d5..866bbe00139 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/multitenant/MultiTenantTestBase.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/multitenant/MultiTenantTestBase.java
@@ -38,8 +38,6 @@
import org.apache.sysds.test.AutomatedTestBase;
import org.junit.After;
-import com.google.crypto.tink.subtle.Random;
-
public abstract class MultiTenantTestBase extends AutomatedTestBase {
protected static final Log LOG = LogFactory.getLog(MultiTenantTestBase.class.getName());
@@ -63,7 +61,8 @@ protected int[] startFedWorkers(int numFedWorkers) {
}
/**
- * Start numFedWorkers federated worker processes on available ports and add them to the workerProcesses
+ * Start numFedWorkers federated worker processes on available ports and add them to the workerProcesses.
+ * Workers are spawned together and their port-bind is awaited in parallel.
*
* @param numFedWorkers the number of federated workers to start
* @return int[] the ports of the created federated workers
@@ -72,10 +71,10 @@ protected int[] startFedWorkers(int numFedWorkers, String[] addArgs) {
int[] ports = new int[numFedWorkers];
for(int counter = 0; counter < numFedWorkers; counter++) {
ports[counter] = getRandomAvailablePort();
- // start process but only wait long for last one.
- Process tmpProcess = startLocalFedWorker(ports[counter], addArgs,
- counter == numFedWorkers - 1 ? (FED_WORKER_WAIT + Random.randInt(1000)) * 3 : FED_WORKER_WAIT_S);
- workerProcesses.add(tmpProcess);
+ }
+ Process[] processes = startLocalFedWorkers(ports, addArgs);
+ for(Process p : processes) {
+ workerProcesses.add(p);
}
return ports;
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBinaryMatrixTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBinaryMatrixTest.java
index 8991d28194f..20945303b0b 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBinaryMatrixTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBinaryMatrixTest.java
@@ -105,11 +105,10 @@ public void federatedMultiply(Types.ExecMode execMode, boolean federatedCompilat
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2});
try {
- if(!isAlive(t1, t2))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
@@ -134,7 +133,7 @@ public void federatedMultiply(Types.ExecMode execMode, boolean federatedCompilat
}
finally {
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
OptimizerUtils.FEDERATED_COMPILATION = false;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBinaryVectorTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBinaryVectorTest.java
index e038d9efda0..d47a7efcfa4 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBinaryVectorTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBinaryVectorTest.java
@@ -97,11 +97,10 @@ public void federatedMultiply(Types.ExecMode execMode, boolean federatedCompilat
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2});
try {
- if(!isAlive(t1, t2))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
@@ -127,7 +126,7 @@ public void federatedMultiply(Types.ExecMode execMode, boolean federatedCompilat
}
finally {
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
OptimizerUtils.FEDERATED_COMPILATION = false;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBroadcastTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBroadcastTest.java
index 434eb08e262..1a88f7a0d39 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBroadcastTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBroadcastTest.java
@@ -90,11 +90,10 @@ public void federatedBroadcast(Types.ExecMode execMode) {
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2});
try {
- if(!isAlive(t1, t2))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
@@ -117,7 +116,7 @@ public void federatedBroadcast(Types.ExecMode execMode) {
}
finally {
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCastToFrameTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCastToFrameTest.java
index a7bc0d80640..392e53aa77b 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCastToFrameTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCastToFrameTest.java
@@ -97,12 +97,10 @@ public void federatedMultiply(Types.ExecMode execMode) {
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2});
-
try {
- if(!isAlive(t1, t2))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
@@ -127,7 +125,7 @@ public void federatedMultiply(Types.ExecMode execMode) {
}
finally {
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCastToMatrixTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCastToMatrixTest.java
index ccfd4a6c430..832ae98d8ed 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCastToMatrixTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCastToMatrixTest.java
@@ -122,10 +122,9 @@ public void federatedMultiply(Types.ExecMode execMode) {
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2});
try {
- if(!isAlive(t1, t2))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
@@ -152,7 +151,7 @@ public void federatedMultiply(Types.ExecMode execMode) {
}
finally {
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCentralMomentTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCentralMomentTest.java
index 1d9f951e78a..b52a1d60262 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCentralMomentTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCentralMomentTest.java
@@ -118,14 +118,10 @@ public void federatedCentralMoment(Types.ExecMode execMode, boolean isWeighted)
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4, FED_WORKER_WAIT + 1000);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4}, null, FED_WORKER_WAIT + 1000);
-
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
// reference file should not be written to hdfs, so we set platform here
@@ -187,7 +183,7 @@ public void federatedCentralMoment(Types.ExecMode execMode, boolean isWeighted)
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedColAggregateTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedColAggregateTest.java
index f52a74bd4fc..5254378cf27 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedColAggregateTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedColAggregateTest.java
@@ -174,13 +174,10 @@ private void runAggregateOperationTest(OpType type, ExecMode execMode) {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
rtplatform = execMode;
@@ -244,7 +241,7 @@ private void runAggregateOperationTest(OpType type, ExecMode execMode) {
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedLeftIndexTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedLeftIndexTest.java
index 30d6b8fc0ec..6f9744ded1a 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedLeftIndexTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedLeftIndexTest.java
@@ -165,13 +165,10 @@ else if(dataType == DataType.FRAME)
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
@@ -212,7 +209,7 @@ else if(dataType == DataType.FRAME)
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
}
}
finally {
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedMisAlignedTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedMisAlignedTest.java
index 8fb8a80663e..06f86dd61b4 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedMisAlignedTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedMisAlignedTest.java
@@ -220,14 +220,10 @@ private void runMisAlignedTest(OpType type, ExecMode execMode, MisAlignmentType
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4});
-
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
// Run reference dml script with normal matrix
@@ -281,7 +277,7 @@ private void runMisAlignedTest(OpType type, ExecMode execMode, MisAlignmentType
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedMultiplyTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedMultiplyTest.java
index 05e4954af15..cfdac1932ff 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedMultiplyTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedMultiplyTest.java
@@ -114,12 +114,11 @@ private void federatedMultiply(Types.ExecMode execMode, boolean federatedCompila
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2});
try {
- if(!isAlive(t1, t2))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
@@ -145,7 +144,7 @@ private void federatedMultiply(Types.ExecMode execMode, boolean federatedCompila
}
finally {
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
OptimizerUtils.FEDERATED_COMPILATION = false;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedProdTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedProdTest.java
index 530a8b7b55e..2447143de2b 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedProdTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedProdTest.java
@@ -108,14 +108,11 @@ private void runProdTest(ExecMode execMode) {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
rtplatform = execMode;
if(rtplatform == ExecMode.SPARK) {
@@ -154,7 +151,7 @@ private void runProdTest(ExecMode execMode) {
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedQuantileTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedQuantileTest.java
index 25d460526e9..6f13fce81b2 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedQuantileTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedQuantileTest.java
@@ -164,12 +164,13 @@ public void federatedQuartile(Types.ExecMode execMode, String TEST_NAME, double
port2 = getRandomAvailablePort();
port3 = getRandomAvailablePort();
port4 = getRandomAvailablePort();
- t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4});
+ t1 = workers[0];
+ t2 = workers[1];
+ t3 = workers[2];
+ t4 = workers[3];
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
programArgs1 = new String[] {"-explain", "-stats", "100", "-args", String.valueOf(p), expected("S"),
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedQuantileWeightsTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedQuantileWeightsTest.java
index dadbd6b5905..ebf4708b3e9 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedQuantileWeightsTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedQuantileWeightsTest.java
@@ -132,12 +132,13 @@ public void federatedQuartile(Types.ExecMode execMode, String TEST_NAME, double
port2 = getRandomAvailablePort();
port3 = getRandomAvailablePort();
port4 = getRandomAvailablePort();
- t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4});
+ t1 = workers[0];
+ t2 = workers[1];
+ t3 = workers[2];
+ t4 = workers[3];
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
programArgs1 = new String[] {"-explain", "-stats", "100", "-args", String.valueOf(p), expected("S"),
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRCBindTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRCBindTest.java
index d8ef8ca4c20..be69fac747d 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRCBindTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRCBindTest.java
@@ -109,14 +109,11 @@ public void federatedRCBind(ExecMode execMode) {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
// we need the reference file to not be written to hdfs, so we get the correct format
@@ -158,7 +155,7 @@ public void federatedRCBind(ExecMode execMode) {
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRdiagTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRdiagTest.java
index 46bf7a45650..ce6aba438e0 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRdiagTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRdiagTest.java
@@ -112,14 +112,11 @@ public void federatedRdiag(Types.ExecMode execMode, boolean activateFedCompilati
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4});
ProgramBlock.CHECK_MATRIX_PROPERTIES = true;
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
// reference file should not be written to hdfs, so we set platform here
@@ -159,7 +156,7 @@ public void federatedRdiag(Types.ExecMode execMode, boolean activateFedCompilati
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
OptimizerUtils.FEDERATED_COMPILATION = false;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRemoveEmptyTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRemoveEmptyTest.java
index 71f7c583665..0714a6e6fb7 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRemoveEmptyTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRemoveEmptyTest.java
@@ -121,14 +121,11 @@ private void runAggregateOperationTest(ExecMode execMode) {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
rtplatform = execMode;
@@ -166,7 +163,7 @@ private void runAggregateOperationTest(ExecMode execMode) {
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedReplaceTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedReplaceTest.java
index d2ec98a4a8d..c5347d6276a 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedReplaceTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedReplaceTest.java
@@ -108,13 +108,10 @@ private void runAggregateOperationTest(ExecMode execMode) {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
rtplatform = execMode;
if(rtplatform == ExecMode.SPARK) {
@@ -153,7 +150,7 @@ private void runAggregateOperationTest(ExecMode execMode) {
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedReshapeTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedReshapeTest.java
index 5d6887c3e2d..fb8bc619ec5 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedReshapeTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedReshapeTest.java
@@ -101,14 +101,11 @@ public void federatedReshape(Types.ExecMode execMode) {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
// reference file should not be written to hdfs, so we set platform here
rtplatform = execMode;
@@ -146,7 +143,7 @@ public void federatedReshape(Types.ExecMode execMode) {
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRevTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRevTest.java
index e4b7ed5e24a..c4f30a574b0 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRevTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRevTest.java
@@ -129,14 +129,11 @@ private void runRevTest(ExecMode execMode, boolean activateFedCompilation) {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
rtplatform = execMode;
if(rtplatform == ExecMode.SPARK) {
@@ -176,7 +173,7 @@ private void runRevTest(ExecMode execMode, boolean activateFedCompilation) {
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRightIndexTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRightIndexTest.java
index b8fe21ef52d..7a72a868b6e 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRightIndexTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRightIndexTest.java
@@ -178,13 +178,10 @@ private void runAggregateOperationTest(IndexType indexType, DataType dataType, E
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
rtplatform = execMode;
@@ -236,7 +233,7 @@ private void runAggregateOperationTest(IndexType indexType, DataType dataType, E
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRollTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRollTest.java
index f242710338d..aa367ea7d4d 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRollTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRollTest.java
@@ -131,14 +131,11 @@ private void runRollTest(ExecMode execMode, boolean activateFedCompilation) {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4});
try {
- if (!isAlive(t1, t2, t3, t4))
+ if (!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
rtplatform = execMode;
if (rtplatform == ExecMode.SPARK) {
@@ -177,7 +174,7 @@ private void runRollTest(ExecMode execMode, boolean activateFedCompilation) {
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
} finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRowIndexTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRowIndexTest.java
index 25b02ca1537..37b6f8826e4 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRowIndexTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRowIndexTest.java
@@ -108,13 +108,10 @@ private void runRowIndexTest(ExecMode execMode) {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
rtplatform = execMode;
@@ -155,7 +152,7 @@ private void runRowIndexTest(ExecMode execMode) {
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedSplitTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedSplitTest.java
index fc9e4a73ba5..462181e75ea 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedSplitTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedSplitTest.java
@@ -106,11 +106,10 @@ public void federatedSplit(Types.ExecMode execMode) {
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2});
try {
- if(!isAlive(t1, t2))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
// Run reference dml script with normal matrix
@@ -141,7 +140,7 @@ else if(execMode != Types.ExecMode.SPARK) {
}
finally {
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedStatisticsTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedStatisticsTest.java
index ac9706a44ae..dec5fac4432 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedStatisticsTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedStatisticsTest.java
@@ -96,10 +96,9 @@ public void federatedLogReg(Types.ExecMode execMode) {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2});
try {
- if(!isAlive(t1, t2))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
@@ -134,7 +133,7 @@ public void federatedLogReg(Types.ExecMode execMode) {
}
finally {
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
resetExecMode(platformOld);
}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTokenizeTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTokenizeTest.java
index f8acc4623a7..84b632c5e2c 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTokenizeTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTokenizeTest.java
@@ -91,12 +91,10 @@ private void runAggregateOperationTest(ExecMode execMode) {
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3});
try {
- if(!isAlive(t1, t2, t3))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
FileFormatPropertiesCSV ffpCSV = new FileFormatPropertiesCSV(false, DataExpression.DEFAULT_DELIM_DELIMITER,
@@ -143,7 +141,7 @@ private void runAggregateOperationTest(ExecMode execMode) {
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3);
+ TestUtils.shutdownThreads(workers);
}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTransferLocalDataTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTransferLocalDataTest.java
index 0d29617eb76..2f2adf2bc93 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTransferLocalDataTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTransferLocalDataTest.java
@@ -88,13 +88,10 @@ private void runTransferTest(Types.ExecMode execMode) {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
rtplatform = execMode;
@@ -125,7 +122,7 @@ private void runTransferTest(Types.ExecMode execMode) {
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
resetExecMode(platformOld);
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTriTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTriTest.java
index 63432e44658..837fb979fc6 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTriTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTriTest.java
@@ -108,12 +108,9 @@ private void runTriTest(ExecMode execMode) {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
rtplatform = execMode;
@@ -151,7 +148,7 @@ private void runTriTest(ExecMode execMode) {
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedCrossEntropyTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedCrossEntropyTest.java
index bd48a45924f..5e0953d0aa5 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedCrossEntropyTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedCrossEntropyTest.java
@@ -136,11 +136,10 @@ public void federatedWeightedCrossEntropy(String testname, ExecMode execMode) {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2});
try {
- if(!isAlive(t1, t2))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
getAndLoadTestConfiguration(testname);
@@ -172,7 +171,7 @@ public void federatedWeightedCrossEntropy(String testname, ExecMode execMode) {
}
finally {
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
resetExecMode(platform_old);
}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedDivMatrixMultTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedDivMatrixMultTest.java
index 6753774f653..3f866be9844 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedDivMatrixMultTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedDivMatrixMultTest.java
@@ -297,11 +297,10 @@ public void federatedWeightedDivMatrixMult(String test_name, ExecMode exec_mode)
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2});
try {
- if(!isAlive(t1, t2))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
getAndLoadTestConfiguration(test_name);
@@ -334,7 +333,7 @@ public void federatedWeightedDivMatrixMult(String test_name, ExecMode exec_mode)
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
}
finally {
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
resetExecMode(platform_old);
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedSigmoidTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedSigmoidTest.java
index ad21760d086..15a709a2057 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedSigmoidTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedSigmoidTest.java
@@ -161,11 +161,10 @@ public void federatedWeightedSigmoid(String test_name, ExecMode exec_mode) {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2});
try {
- if(!isAlive(t1, t2))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
getAndLoadTestConfiguration(test_name);
@@ -198,7 +197,7 @@ public void federatedWeightedSigmoid(String test_name, ExecMode exec_mode) {
}
finally {
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
resetExecMode(platform_old);
}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedSquaredLossTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedSquaredLossTest.java
index 7fac163c4a6..111c16ded84 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedSquaredLossTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedSquaredLossTest.java
@@ -151,11 +151,10 @@ public void federatedWeightedSquaredLoss(String test_name, ExecMode exec_mode) {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2});
try {
- if(!isAlive(t1, t2))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
getAndLoadTestConfiguration(test_name);
@@ -187,7 +186,7 @@ public void federatedWeightedSquaredLoss(String test_name, ExecMode exec_mode) {
}
finally {
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
resetExecMode(platform_old);
}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedUnaryMatrixMultTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedUnaryMatrixMultTest.java
index 8019edf83f8..610b5ebedf7 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedUnaryMatrixMultTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedUnaryMatrixMultTest.java
@@ -160,11 +160,10 @@ public void federatedWeightedUnaryMatrixMult(String test_name, ExecMode exec_mod
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2});
try {
- if(!isAlive(t1, t2))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
getAndLoadTestConfiguration(test_name);
@@ -197,7 +196,7 @@ public void federatedWeightedUnaryMatrixMult(String test_name, ExecMode exec_mod
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
}
finally {
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
resetExecMode(platform_old);
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part4/FederatedRowAggregateTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part4/FederatedRowAggregateTest.java
index ba725519a8c..1088de53625 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part4/FederatedRowAggregateTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part4/FederatedRowAggregateTest.java
@@ -214,14 +214,11 @@ private void runAggregateOperationTest(OpType type, ExecMode execMode) {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
// Run reference dml script with normal matrix
@@ -280,7 +277,7 @@ private void runAggregateOperationTest(OpType type, ExecMode execMode) {
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java
index 48c9cab632a..0993b41698f 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java
@@ -134,13 +134,10 @@ private void runCovarianceTest(ExecMode execMode, boolean alignedFedInput) {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4, FED_WORKER_WAIT);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
setExecMode(execMode);
@@ -203,7 +200,7 @@ private void runCovarianceTest(ExecMode execMode, boolean alignedFedInput) {
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
@@ -246,13 +243,10 @@ private void runWeightedCovarianceTest(ExecMode execMode, boolean alignedInput,
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4, FED_WORKER_WAIT);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
@@ -364,7 +358,7 @@ private void runWeightedCovarianceTest(ExecMode execMode, boolean alignedInput,
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCtableTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCtableTest.java
index 4ba19fb3d0f..ebb42dfec86 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCtableTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCtableTest.java
@@ -133,14 +133,11 @@ public void runCtable(Types.ExecMode execMode, boolean fedOutput, boolean matrix
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
@@ -154,7 +151,7 @@ public void runCtable(Types.ExecMode execMode, boolean fedOutput, boolean matrix
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
resetExecMode(platformOld);
}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFrameMapTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFrameMapTest.java
index 037675c584c..189c109fa8b 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFrameMapTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFrameMapTest.java
@@ -101,14 +101,11 @@ private void runAggregateOperationTest(ExecMode execMode) {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
rtplatform = execMode;
@@ -149,7 +146,7 @@ private void runAggregateOperationTest(ExecMode execMode) {
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
}
}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullAggregateTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullAggregateTest.java
index d75924746c5..006b5d76883 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullAggregateTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullAggregateTest.java
@@ -203,13 +203,10 @@ private void runAggregateOperationTest(OpType type, ExecType instType) {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
@@ -264,7 +261,7 @@ private void runAggregateOperationTest(OpType type, ExecType instType) {
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
resetExecMode(platformOld);
}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullCumulativeTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullCumulativeTest.java
index 5eb1179efd9..b20d949857e 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullCumulativeTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullCumulativeTest.java
@@ -168,14 +168,11 @@ private void runCumOperationTest(OpType type, ExecType instType) {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
@@ -245,7 +242,7 @@ private void runCumOperationTest(OpType type, ExecType instType) {
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
resetExecMode(platformOld);
}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedIfelseTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedIfelseTest.java
index abf0c8c228f..375a4f95181 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedIfelseTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedIfelseTest.java
@@ -138,14 +138,11 @@ private void runTernaryTest(ExecMode execMode, boolean aligned, boolean singleMa
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1);
- Process t2 = startLocalFedWorker(port2);
- Process t3 = startLocalFedWorker(port3);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
rtplatform = execMode;
@@ -220,7 +217,7 @@ private void runTernaryTest(ExecMode execMode, boolean aligned, boolean singleMa
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedMMChainTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedMMChainTest.java
index c37341ed681..73cb5470a38 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedMMChainTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedMMChainTest.java
@@ -138,14 +138,11 @@ private void runMMChainTest(ExecMode execMode, String TEST_NAME) {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1, port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting federated worker");
rtplatform = execMode;
@@ -181,7 +178,7 @@ private void runMMChainTest(ExecMode execMode, String TEST_NAME) {
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java
index b70bdd2940b..fd1f88f2761 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java
@@ -230,7 +230,7 @@ private void runTransformTest(TransformType type, boolean colnames, boolean line
default: throw new RuntimeException("Not supported type");
}
- Thread t1 = null, t2 = null, t3 = null, t4 = null;
+ Thread[] workers = null;
try {
getAndLoadTestConfiguration(TEST_NAME1);
@@ -239,10 +239,7 @@ private void runTransformTest(TransformType type, boolean colnames, boolean line
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
String[] otherargs = lineage ? new String[] {"-lineage", "reuse_full"} : null;
- t1 = startLocalFedWorkerThread(port1, otherargs);
- t2 = startLocalFedWorkerThread(port2, otherargs);
- t3 = startLocalFedWorkerThread(port3, otherargs);
- t4 = startLocalFedWorkerThread(port4, otherargs);
+ workers = startLocalFedWorkerThreads(new int[] {port1, port2, port3, port4}, otherargs);
FileFormatPropertiesCSV ffpCSV = new FileFormatPropertiesCSV(true, DataExpression.DEFAULT_DELIM_DELIMITER,
DataExpression.DEFAULT_DELIM_FILL, DataExpression.DEFAULT_DELIM_FILL_VALUE, DATASET.equals(DATASET1) ?
@@ -345,7 +342,7 @@ private void runTransformTest(TransformType type, boolean colnames, boolean line
throw new RuntimeException(ex);
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
resetExecMode(rtold);
}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java
index f144e03984a..8c94e62692d 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java
@@ -127,7 +127,7 @@ public void runComplexRecodeTestBinarySparseCP() {
private void runTransformEncodeDecodeTest(boolean recode, boolean sparse, Types.FileFormat format) {
ExecMode rtold = setExecMode(ExecMode.SINGLE_NODE);
- Thread t1 = null, t2 = null, t3 = null, t4 = null;
+ Thread[] workers = null;
try {
getAndLoadTestConfiguration(TEST_NAME_RECODE);
@@ -135,10 +135,7 @@ private void runTransformEncodeDecodeTest(boolean recode, boolean sparse, Types.
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
- t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
- t4 = startLocalFedWorkerThread(port4);
+ workers = startLocalFedWorkerThreads(new int[] {port1, port2, port3, port4}, null, FED_WORKER_WAIT);
// schema
Types.ValueType[] schema = new Types.ValueType[cols / 2];
@@ -205,7 +202,7 @@ private void runTransformEncodeDecodeTest(boolean recode, boolean sparse, Types.
Assert.fail(ex.getMessage());
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
resetExecMode(rtold);
}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/lineage/FedFullReuseTest.java b/src/test/java/org/apache/sysds/test/functions/lineage/FedFullReuseTest.java
index ca133caa76e..0c46cd68ea5 100644
--- a/src/test/java/org/apache/sysds/test/functions/lineage/FedFullReuseTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/lineage/FedFullReuseTest.java
@@ -103,8 +103,7 @@ public void federatedReuse(String test) {
int port2 = getRandomAvailablePort();
String[] otherargs = new String[] {"-lineage", "reuse_full"};
Lineage.resetInternalState();
- Thread t1 = startLocalFedWorkerThread(port1, otherargs, FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2, otherargs);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}, otherargs, FED_WORKER_WAIT);
TestConfiguration config = availableTestConfigurations.get(test);
loadTestConfiguration(config);
@@ -149,7 +148,7 @@ public void federatedReuse(String test) {
}
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/lineage/FedUDFReuseTest.java b/src/test/java/org/apache/sysds/test/functions/lineage/FedUDFReuseTest.java
index f7f01f6f9ef..0cf9d972719 100644
--- a/src/test/java/org/apache/sysds/test/functions/lineage/FedUDFReuseTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/lineage/FedUDFReuseTest.java
@@ -108,10 +108,7 @@ private void runTriUDFReuse(ExecMode execMode) {
int port4 = getRandomAvailablePort();
String[] otherargs = new String[] {"-lineage", "reuse_full"};
Lineage.resetInternalState();
- Thread t1 = startLocalFedWorkerThread(port1, otherargs, FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2, otherargs, FED_WORKER_WAIT_S);
- Thread t3 = startLocalFedWorkerThread(port3, otherargs, FED_WORKER_WAIT_S);
- Thread t4 = startLocalFedWorkerThread(port4, otherargs);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2, port3, port4}, otherargs, FED_WORKER_WAIT);
rtplatform = execMode;
if(rtplatform == ExecMode.SPARK) {
@@ -146,7 +143,7 @@ private void runTriUDFReuse(ExecMode execMode) {
// assert reuse count
Assert.assertTrue(LineageCacheStatistics.getInstHits() > 0);
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git a/src/test/java/org/apache/sysds/test/functions/lineage/LineageFedReuseAlg.java b/src/test/java/org/apache/sysds/test/functions/lineage/LineageFedReuseAlg.java
index 15f6a7978e5..b8a66196670 100644
--- a/src/test/java/org/apache/sysds/test/functions/lineage/LineageFedReuseAlg.java
+++ b/src/test/java/org/apache/sysds/test/functions/lineage/LineageFedReuseAlg.java
@@ -93,8 +93,7 @@ public void federatedLmPipeline(ExecMode execMode, boolean contSplits, String TE
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
String[] otherargs = new String[] {"-lineage", "reuse_full"};
- Thread t1 = startLocalFedWorkerThread(port1, otherargs, FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2, otherargs);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}, otherargs, FED_WORKER_WAIT);
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
@@ -136,7 +135,7 @@ public void federatedLmPipeline(ExecMode execMode, boolean contSplits, String TE
assertTrue(mmCount > mmCount_reuse);
assertTrue(fed_mmCount > fed_mmCount_reuse);
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
}
finally {
resetExecMode(oldExec);