From 672f06c7da81b7984c7f386d659d6bec52382d80 Mon Sep 17 00:00:00 2001 From: Enrico Olivelli Date: Mon, 18 May 2026 14:58:10 +0200 Subject: [PATCH] Parallelize PQRetrainer training-vector extraction across sources MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PQRetrainer.extractVectorsSequential read every training sample with a single-threaded blocking getVectorInto() loop. Against remote storage each read is a network round-trip with no OS read-ahead, so thousands of samples serialize into thousands of round-trips — observed as a 2+ hour stall during a 53-segment compaction (HerdDB issue #587). Extract each source on its own thread/View (one RandomAccessReader per View, never shared) so up to jvector.pq.retrain.io.threads (default 16) remote reads are in flight at once; within a source reads stay ascending for read-ahead friendliness. Also close the previously-leaked Views and emit periodic progress logs so a slow extraction is distinguishable from a hang. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- .../jvector/graph/disk/PQRetrainer.java | 130 +++++++++++++++--- .../disk/TestOnDiskGraphIndexCompactor.java | 95 +++++++++++++ 2 files changed, 209 insertions(+), 16 deletions(-) diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/PQRetrainer.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/PQRetrainer.java index a0438168e..86a93c584 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/PQRetrainer.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/PQRetrainer.java @@ -29,10 +29,19 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.io.IOException; import java.util.ArrayList; +import java.util.Collections; import java.util.Comparator; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.atomic.AtomicInteger; /** * Handles Product Quantization retraining for graph index compaction. @@ -48,6 +57,18 @@ public class PQRetrainer { // avoiding the random I/O that would happen with per-node random sampling. private static final int SAMPLE_CHUNK_SIZE = 32; + // Maximum number of source segments whose training vectors are extracted + // concurrently. Each source is read by a single thread using its own View, + // so this is the number of remote-storage reads that can be in flight at + // once. Defaults to 16; overridable via the system property below for + // deployments backed by slower or faster storage. Clamped to >= 1. 
+ private static final int IO_THREADS = Math.max(1, + Integer.getInteger("jvector.pq.retrain.io.threads", 16)); + + // Emit a progress log line every this many extracted training vectors, so + // a slow-but-progressing extraction can be distinguished from a hang. + private static final int PROGRESS_LOG_INTERVAL = 10_000; + private final List sources; private final List liveNodes; private final List numLiveNodesPerSource; @@ -80,16 +101,15 @@ public ProductQuantization retrain(VectorSimilarityFunction similarityFunction) List samples = sampleBalanced(ProductQuantization.MAX_PQ_TRAINING_SET_SIZE); // Sort by (source, node) so extractVectorsSequential reads each source's file - // in ascending order, enabling OS read-ahead instead of random page faults. + // in ascending order, enabling read-ahead instead of random page faults. samples.sort(Comparator.comparingInt((SampleRef r) -> r.source).thenComparingInt(r -> r.node)); log.info("Collected {} training samples", samples.size()); - // Extract vectors sequentially in sorted (source, node) order so disk reads are - // purely sequential and the OS read-ahead can cover them efficiently. We do this - // here rather than letting ProductQuantization.compute() drive the reads via its - // parallel stream, which would scatter page faults across a potentially very large - // file and cause I/O that scales with dataset size rather than sample count. + // Extract vectors up front so ProductQuantization.compute() itself performs no + // I/O. extractVectorsSequential reads each source ascending (read-ahead friendly) + // and reads distinct sources concurrently so the per-read latency of remote + // storage is hidden instead of serialized across thousands of samples. List> trainingVectors = extractVectorsSequential(samples); var ravv = new ListRandomAccessVectorValues(trainingVectors, dimension); @@ -198,22 +218,100 @@ private List sampleBalanced(int totalSamples) { } /** - * Reads sampled vectors in the order provided. 
The caller must pre-sort {@code samples} - * by (source, node) so reads within each source are ascending, letting the OS read-ahead - * cover them efficiently. Each source's view is opened once and reused for all its samples. + * Reads the sampled vectors into memory. + * + *

<p>The caller pre-sorts {@code samples} by (source, node) so each source's + * sub-list is contiguous and ascending. Sources are then extracted + * concurrently: each source is processed by a single worker thread + * using its own {@link OnDiskGraphIndex.View} (and therefore its own + * {@code RandomAccessReader}), so concurrent reads are safe and up to + * {@link #IO_THREADS} remote-storage reads can be in flight at once. Within + * a source the reads stay ascending, preserving read-ahead friendliness. + * + *

<p>This matters for graphs backed by remote storage (e.g. an object + * store): there each {@code getVectorInto} is a network round-trip with no + * OS read-ahead, so a single-threaded loop over thousands of samples + * serializes thousands of round-trips. Reading sources in parallel hides + * that latency. The returned list order is unspecified — irrelevant, since + * it only feeds {@link ProductQuantization#compute} as an unordered + * training set. */ private List> extractVectorsSequential(List samples) { - OnDiskGraphIndex.View[] views = new OnDiskGraphIndex.View[sources.size()]; - for (int s = 0; s < sources.size(); s++) { - views[s] = (OnDiskGraphIndex.View) sources.get(s).getView(); + if (samples.isEmpty()) { + return new ArrayList<>(); } - List> vectors = new ArrayList<>(samples.size()); - VectorFloat tmp = vectorTypeSupport.createFloatVector(dimension); + // Group the pre-sorted samples by source. The caller sorted by + // (source, node) so each source's sub-list is contiguous and ascending. + Map> bySource = new LinkedHashMap<>(); for (SampleRef ref : samples) { - views[ref.source].getVectorInto(ref.node, tmp, 0); - vectors.add(tmp.copy()); + bySource.computeIfAbsent(ref.source, k -> new ArrayList<>()).add(ref); + } + + int parallelism = Math.min(bySource.size(), IO_THREADS); + // Order is irrelevant for PQ codebook training, so a synchronized list + // collecting from all workers is sufficient. 
+ List> vectors = + Collections.synchronizedList(new ArrayList<>(samples.size())); + AtomicInteger progress = new AtomicInteger(); + int totalSamples = samples.size(); + long startNanos = System.nanoTime(); + + ExecutorService pool = Executors.newFixedThreadPool(parallelism, r -> { + Thread t = new Thread(r, "pq-retrain-io"); + t.setDaemon(true); + return t; + }); + try { + List> futures = new ArrayList<>(bySource.size()); + for (Map.Entry> entry : bySource.entrySet()) { + final int source = entry.getKey(); + final List group = entry.getValue(); + futures.add(pool.submit(() -> { + // One View — and one RandomAccessReader — per task. Never + // shared across threads, so concurrent extraction is safe. + OnDiskGraphIndex.View view = + (OnDiskGraphIndex.View) sources.get(source).getView(); + try { + VectorFloat scratch = vectorTypeSupport.createFloatVector(dimension); + for (SampleRef ref : group) { + view.getVectorInto(ref.node, scratch, 0); + vectors.add(scratch.copy()); + int done = progress.incrementAndGet(); + if (done % PROGRESS_LOG_INTERVAL == 0) { + log.info("PQ retraining: extracted {}/{} training vectors", + done, totalSamples); + } + } + } finally { + try { + view.close(); + } catch (IOException ioe) { + log.warn("Failed to close source {} view during PQ retraining", + source, ioe); + } + } + })); + } + // Wait for completion; surface the first failure. 
+ for (Future f : futures) { + try { + f.get(); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + throw new RuntimeException("PQ retraining vector extraction interrupted", ie); + } catch (ExecutionException ee) { + throw new RuntimeException("PQ retraining vector extraction failed", + ee.getCause()); + } + } + } finally { + pool.shutdownNow(); } + + log.info("PQ retraining: extracted {} training vectors from {} sources in {} ms ({} threads)", + vectors.size(), bySource.size(), + (System.nanoTime() - startNanos) / 1_000_000L, parallelism); return vectors; } diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestOnDiskGraphIndexCompactor.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestOnDiskGraphIndexCompactor.java index 410b96d0e..61756f090 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestOnDiskGraphIndexCompactor.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestOnDiskGraphIndexCompactor.java @@ -536,6 +536,101 @@ public void testCompact() throws Exception { searcher.close(); } + /** + * Builds a FusedPQ-enabled source graph from the given vectors and writes it + * to {@code name} under the test directory. Mirrors {@link #buildFusedPQ()} + * but is parameterised so a test can create an arbitrary number of sources. 
+ */ + private Path buildFusedPQSource(List> vecs, String name) throws IOException { + RandomAccessVectorValues ravv = new ListRandomAccessVectorValues(vecs, dimension); + ProductQuantization pq = ProductQuantization.compute(ravv, 8, 256, true, UNWEIGHTED, simdExecutor, parallelExecutor); + PQVectors pqv = (PQVectors) pq.encodeAll(ravv, simdExecutor); + var bsp = BuildScoreProvider.pqBuildScoreProvider(similarityFunction, pqv); + var builder = new GraphIndexBuilder(bsp, dimension, 16, 100, 1.2f, 1.2f, false, true, simdExecutor, parallelExecutor); + var graph = builder.getGraph(); + + var outputPath = testDirectory.resolve(name); + Map> writeSuppliers = new EnumMap<>(FeatureId.class); + writeSuppliers.put(FeatureId.INLINE_VECTORS, ordinal -> new InlineVectors.State(ravv.getVector(ordinal))); + + var identityMapper = new OrdinalMapper.IdentityMapper(ravv.size() - 1); + var writerBuilder = new OnDiskGraphIndexWriter.Builder(graph, outputPath); + writerBuilder.withMapper(identityMapper); + writerBuilder.with(new InlineVectors(dimension)); + writerBuilder.with(new FusedPQ(graph.maxDegree(), pq)); + var writer = writerBuilder.build(); + + for (var node = 0; node < ravv.size(); node++) { + var stateMap = new EnumMap(FeatureId.class); + stateMap.put(FeatureId.INLINE_VECTORS, writeSuppliers.get(FeatureId.INLINE_VECTORS).apply(node)); + writer.writeInline(node, stateMap); + builder.addGraphNode(node, ravv.getVector(node)); + } + builder.cleanup(); + + writeSuppliers.put(FeatureId.FUSED_PQ, ordinal -> new FusedPQ.State(graph.getView(), pqv, ordinal)); + writer.write(writeSuppliers); + return outputPath; + } + + /** + * Exercises the parallel PQ-retraining path (issue #587): with many FusedPQ + * sources, {@code PQRetrainer.extractVectorsSequential} reads each source on + * its own thread/View. 
This test compacts more sources than the default + * {@code testCompact} uses, then verifies that every source's inline vectors + * survive the compaction exactly at their remapped ordinals — proving the + * concurrent extraction did not throw, deadlock, drop, or mis-assign vectors. + */ + @Test + public void testCompactManySourcesParallelRetrain() throws Exception { + final int sourceCount = 8; + final int nodesPerSource = 256; + + List graphs = new ArrayList<>(); + List rss = new ArrayList<>(); + List liveNodes = new ArrayList<>(); + List remappers = new ArrayList<>(); + // expected[globalOrdinal] -> original vector + List> expected = new ArrayList<>(); + + int globalOrdinal = 0; + for (int s = 0; s < sourceCount; s++) { + List> vecs = createRandomVectors(nodesPerSource, dimension); + Path path = buildFusedPQSource(vecs, "parallel_retrain_src_" + s); + + ReaderSupplier rs = ReaderSupplierFactory.open(path); + rss.add(rs); + graphs.add(OnDiskGraphIndex.load(rs)); + + Map map = new HashMap<>(nodesPerSource); + for (int i = 0; i < nodesPerSource; i++) { + map.put(i, globalOrdinal++); + expected.add(vecs.get(i)); + } + remappers.add(new OrdinalMapper.MapMapper(map)); + + FixedBitSet live = new FixedBitSet(nodesPerSource); + live.set(0, nodesPerSource); + liveNodes.add(live); + } + + var compactor = new OnDiskGraphIndexCompactor(graphs, liveNodes, remappers, similarityFunction, null); + Path outPath = testDirectory.resolve("parallel_retrain_out"); + compactor.compact(outPath); + + ReaderSupplier rsOut = ReaderSupplierFactory.open(outPath); + OnDiskGraphIndex compacted = OnDiskGraphIndex.load(rsOut); + assertEquals("compacted graph must contain every source's nodes", + sourceCount * nodesPerSource, compacted.size(0)); + + var view = compacted.getView(); + VectorFloat buf = vectorTypeSupport.createFloatVector(dimension); + for (int ord = 0; ord < expected.size(); ord++) { + view.getVectorInto(ord, buf, 0); + assertVecEquals(expected.get(ord), buf, ord); + } + } + /** * 
Tests compaction with deleted nodes. * Verifies that deleted nodes are properly excluded from the compacted graph.