diff --git a/src/main/java/com/oltpbenchmark/benchmarks/tpcc/TPCCBenchmark.java b/src/main/java/com/oltpbenchmark/benchmarks/tpcc/TPCCBenchmark.java index 502b78b..f466001 100644 --- a/src/main/java/com/oltpbenchmark/benchmarks/tpcc/TPCCBenchmark.java +++ b/src/main/java/com/oltpbenchmark/benchmarks/tpcc/TPCCBenchmark.java @@ -86,16 +86,13 @@ private List createTerminalsOldWay() throws SQLException { // totalWarehouses is equal to numWarehouses in case of non-partitioned use case int totalWarehouses = (int) workConf.getScaleFactor(); - if (totalWarehouses <= 0) { - // At least one warehouse, @see - // https://github.com/cmu-db/benchbase/blob/main/src/main/java/com/oltpbenchmark/benchmarks/tpcc/TPCCBenchmark.java - totalWarehouses = 1; - } - - // Default values used for warehouse indexes and stride - final int startWarehouseIndex = 1; - final int endWarehouseIndex = totalWarehouses; - final int stride = 1; + // [startWarehouseIndex, endWarehouseIndex] are both included. + // Use defaults if not configured: start=1, end=totalWarehouses, stride=1 + final int startWarehouseIndex = + workConf.getStartWarehouseIndex() > 0 ? workConf.getStartWarehouseIndex() : 1; + final int endWarehouseIndex = + workConf.getEndWarehouseIndex() > 0 ? workConf.getEndWarehouseIndex() : totalWarehouses; + final int stride = workConf.getStride() > 0 ? workConf.getStride() : 1; LOG.info( "Start warehouse idx: {} end warehouse idx: {} stride: {}", @@ -110,6 +107,10 @@ private List createTerminalsOldWay() throws SQLException { final int numWarehouses = w_ids.size(); int numTerminals = workConf.getTerminals(); + assert startWarehouseIndex >= 1 : "The start index must be >= 1"; + assert endWarehouseIndex >= 1 : "The end index must be >= 1"; + assert endWarehouseIndex <= workConf.getScaleFactor() + : "The end index must be within the scale factor"; assert numWarehouses >= 1 : "At least need 1 warehouse to do benchmark"; // We distribute terminals evenly across the warehouses diff --git a/src/test/java/com/oltpbenchmark/benchmarks/tpcc/TestTPCCBenchmark.java b/src/test/java/com/oltpbenchmark/benchmarks/tpcc/TestTPCCBenchmark.java index 0d83c7c..6f75278 100644 --- a/src/test/java/com/oltpbenchmark/benchmarks/tpcc/TestTPCCBenchmark.java +++ b/src/test/java/com/oltpbenchmark/benchmarks/tpcc/TestTPCCBenchmark.java @@ -16,10 +16,20 @@ package com.oltpbenchmark.benchmarks.tpcc; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + import com.oltpbenchmark.api.AbstractTestBenchmarkModule; +import com.oltpbenchmark.api.BenchmarkModule; import com.oltpbenchmark.api.Procedure; +import com.oltpbenchmark.api.Worker; import com.oltpbenchmark.benchmarks.tpcc.procedures.*; +import java.util.HashMap; +import java.util.HashSet; import java.util.List; +import java.util.Map; +import java.util.Set; +import org.junit.Test; public class TestTPCCBenchmark extends AbstractTestBenchmarkModule { @@ -35,4 +45,186 @@ public List> procedures() { public Class benchmarkClass() { return TPCCBenchmark.class; } + + /** + * Test that workers are created with correct warehouse IDs when using startWarehouseIndex, + * endWarehouseIndex, and stride parameters for distributed benchmarking. + */ + @Test + public void testCreateTerminalsWithStrideDistribution() throws Exception { + // Configure for distributed benchmarking: warehouses 3, 6, 9 (stride=3, start=3, end=9) + int scaleFactor = 10; + int startWarehouseIndex = 3; + int endWarehouseIndex = 9; + int stride = 3; + int terminals = 6; + + this.workConf.setScaleFactor(scaleFactor); + this.workConf.setStartWarehouseIndex(startWarehouseIndex); + this.workConf.setEndWarehouseIndex(endWarehouseIndex); + this.workConf.setStride(stride); + this.workConf.setTerminals(terminals); + + // Create workers + List> workers = this.benchmark.makeWorkers(); + + assertEquals("Should create correct number of workers", terminals, workers.size()); + + // Collect all warehouse IDs assigned to workers + Set assignedWarehouseIds = new HashSet<>(); + for (Worker worker : workers) { + TPCCWorker tpccWorker = (TPCCWorker) worker; + assignedWarehouseIds.add(tpccWorker.getTerminalWarehouseID()); + } + + // Expected warehouses: 3, 6, 9 (start=3, end=9, stride=3) + Set expectedWarehouseIds = Set.of(3, 6, 9); + + assertEquals( + "Workers should only be assigned to warehouses matching stride pattern", + expectedWarehouseIds, + assignedWarehouseIds); + } + + /** + * Test that terminals are evenly distributed across warehouses. With 10 terminals across 3 + * warehouses, distribution should be approximately 3, 3, 4 (or similar even split). + */ + @Test + public void testTerminalsEvenlyDistributedAcrossWarehouses() throws Exception { + int scaleFactor = 10; + int startWarehouseIndex = 1; + int endWarehouseIndex = 3; + int stride = 1; + int terminals = 10; + + this.workConf.setScaleFactor(scaleFactor); + this.workConf.setStartWarehouseIndex(startWarehouseIndex); + this.workConf.setEndWarehouseIndex(endWarehouseIndex); + this.workConf.setStride(stride); + this.workConf.setTerminals(terminals); + + List> workers = this.benchmark.makeWorkers(); + + assertEquals("Should create correct number of workers", terminals, workers.size()); + + // Count terminals per warehouse + Map terminalsPerWarehouse = new HashMap<>(); + for (Worker worker : workers) { + TPCCWorker tpccWorker = (TPCCWorker) worker; + int warehouseId = tpccWorker.getTerminalWarehouseID(); + terminalsPerWarehouse.merge(warehouseId, 1, Integer::sum); + } + + // With 10 terminals across 3 warehouses, expect distribution like 3, 3, 4 + // Each warehouse should have at least floor(10/3)=3 terminals + // and at most ceil(10/3)=4 terminals + int numWarehouses = 3; + int minTerminalsPerWarehouse = 3; + int maxTerminalsPerWarehouse = 4; + + for (Map.Entry entry : terminalsPerWarehouse.entrySet()) { + int warehouseId = entry.getKey(); + int count = entry.getValue(); + assertTrue( + "Warehouse " + + warehouseId + + " has " + + count + + " terminals, expected between " + + minTerminalsPerWarehouse + + " and " + + maxTerminalsPerWarehouse, + count >= minTerminalsPerWarehouse && count <= maxTerminalsPerWarehouse); + } + + // Verify all expected warehouses have terminals + assertEquals( + "All warehouses should have at least one terminal", + numWarehouses, + terminalsPerWarehouse.size()); + + // Verify total terminals assigned equals expected + int totalAssigned = terminalsPerWarehouse.values().stream().mapToInt(Integer::intValue).sum(); + assertEquals( + "Total terminals assigned should equal requested terminals", terminals, totalAssigned); + + // Verify each worker has sequential ID from 0 to terminals-1 + for (int i = 0; i < workers.size(); i++) { + assertEquals("Worker ID should be sequential", i, workers.get(i).getId()); + } + } + + /** + * Test distributed benchmarking scenario: 3 instances each handling different warehouse subsets. + * Instance 1: warehouses 1, 4, 7, 10 (start=1, stride=3) Instance 2: warehouses 2, 5, 8 (start=2, + * stride=3) Instance 3: warehouses 3, 6, 9 (start=3, stride=3) + */ + @Test + public void testDistributedBenchmarkingScenario() throws Exception { + int scaleFactor = 10; + int stride = 3; + int terminalsPerInstance = 8; + + // Simulate instance 1: warehouses 1, 4, 7, 10 + this.workConf.setScaleFactor(scaleFactor); + this.workConf.setStartWarehouseIndex(1); + this.workConf.setEndWarehouseIndex(10); + this.workConf.setStride(stride); + this.workConf.setTerminals(terminalsPerInstance); + + List> workers1 = this.benchmark.makeWorkers(); + + Set instance1Warehouses = new HashSet<>(); + for (Worker worker : workers1) { + instance1Warehouses.add(((TPCCWorker) worker).getTerminalWarehouseID()); + } + + // Instance 1 should only use warehouses 1, 4, 7, 10 + Set expectedInstance1 = Set.of(1, 4, 7, 10); + assertEquals( + "Instance 1 should use warehouses 1, 4, 7, 10", expectedInstance1, instance1Warehouses); + + // Simulate instance 2: warehouses 2, 5, 8 + this.workConf.setStartWarehouseIndex(2); + this.workConf.setEndWarehouseIndex(10); + this.benchmark = + this.benchmark + .getClass() + .getConstructor(this.workConf.getClass()) + .newInstance(this.workConf); + + List> workers2 = this.benchmark.makeWorkers(); + + Set instance2Warehouses = new HashSet<>(); + for (Worker worker : workers2) { + instance2Warehouses.add(((TPCCWorker) worker).getTerminalWarehouseID()); + } + + // Instance 2 should only use warehouses 2, 5, 8 + Set expectedInstance2 = Set.of(2, 5, 8); + assertEquals( + "Instance 2 should use warehouses 2, 5, 8", expectedInstance2, instance2Warehouses); + + // Simulate instance 3: warehouses 3, 6, 9 + this.workConf.setStartWarehouseIndex(3); + this.workConf.setEndWarehouseIndex(10); + this.benchmark = + this.benchmark + .getClass() + .getConstructor(this.workConf.getClass()) + .newInstance(this.workConf); + + List> workers3 = this.benchmark.makeWorkers(); + + Set instance3Warehouses = new HashSet<>(); + for (Worker worker : workers3) { + instance3Warehouses.add(((TPCCWorker) worker).getTerminalWarehouseID()); + } + + // Instance 3 should only use warehouses 3, 6, 9 + Set expectedInstance3 = Set.of(3, 6, 9); + assertEquals( + "Instance 3 should use warehouses 3, 6, 9", expectedInstance3, instance3Warehouses); + } }