From cca8fdee1ac58d7f10277ea2ab9a32b4901b5818 Mon Sep 17 00:00:00 2001 From: Enrico Olivelli Date: Tue, 19 May 2026 14:52:52 +0200 Subject: [PATCH] issue #599: add readerSupplierFactory to PQRetrainer and OnDiskGraphIndexCompactor Allow callers to inject a Function<OnDiskGraphIndex, ReaderSupplier> into PQRetrainer and OnDiskGraphIndexCompactor so that PQ retraining vector extraction can use a pre-downloaded or locally-buffered copy of each source graph file instead of per-node block-cache round-trips. Changes: - OnDiskGraphIndex: add getView(ReaderSupplier) overload so a View can be opened against a caller-supplied reader rather than the default one - PQRetrainer: new 4-arg constructor accepting an optional readerSupplierFactory; when non-null, extractVectorsSequential opens each source View via getView(supplier) and closes the supplier after all vectors for that source are extracted - OnDiskGraphIndexCompactor: new 6-arg constructor forwarding the factory to PQRetrainer; existing 5-arg constructor passes null (no change to default behaviour) - TestPQRetrainerCustomReader: unit tests covering getView(supplier), factory invocation count, and backward-compatibility of the old constructor Co-Authored-By: Claude Sonnet 4.6 --- .../jvector/graph/disk/OnDiskGraphIndex.java | 17 ++ .../graph/disk/OnDiskGraphIndexCompactor.java | 39 ++- .../jvector/graph/disk/PQRetrainer.java | 69 ++++- .../disk/TestPQRetrainerCustomReader.java | 271 ++++++++++++++++++ 4 files changed, 392 insertions(+), 4 deletions(-) create mode 100644 jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestPQRetrainerCustomReader.java diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java index 4270225c9..757fcac2f 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java +++ 
b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java @@ -455,6 +455,23 @@ public View getView() { } } + /** + * Opens a {@link View} backed by a caller-supplied {@link ReaderSupplier} instead of the + * index's own internal reader. This allows callers that have already downloaded or buffered + * the index file locally (e.g. for bulk vector extraction during PQ retraining) to avoid + * routing reads through the block cache. + * + *
<p>
The caller is responsible for closing the returned View (which closes the underlying + * reader) and for separately closing the {@code supplier} when it is no longer needed. + * + * @param supplier the reader supplier to use; {@code supplier.get()} is called exactly once + * @return a new View backed by the provided reader + * @throws IOException if {@code supplier.get()} throws + */ + public View getView(ReaderSupplier supplier) throws IOException { + return new View(supplier.get()); + } + @Override public double getAverageDegree(int level) { var view = this.getView(); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexCompactor.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexCompactor.java index 07b96cc85..66cde90bb 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexCompactor.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexCompactor.java @@ -25,7 +25,9 @@ import java.nio.file.StandardOpenOption; import java.util.*; import java.util.concurrent.*; +import java.util.function.Function; import java.util.stream.IntStream; +import io.github.jbellis.jvector.disk.ReaderSupplier; import io.github.jbellis.jvector.graph.*; import io.github.jbellis.jvector.graph.disk.feature.Feature; import io.github.jbellis.jvector.graph.disk.feature.FeatureId; @@ -71,10 +73,19 @@ public final class OnDiskGraphIndexCompactor implements Accountable { private final ForkJoinPool executor; private final int taskWindowSize; private final VectorSimilarityFunction similarityFunction; + /** + * Optional factory for bulk sequential readers used during PQ retraining. + * When non-null, {@link #resolvePQFromSources} passes it to {@link PQRetrainer} + * so that training-vector extraction avoids per-node block-cache round-trips + * (issue #599 Option B). {@code null} uses the default block-cache path. 
+ */ + private final Function readerSupplierFactory; /** * Constructs a new OnDiskGraphIndexCompactor to merge multiple graph indexes. - * Initializes thread pool, validates inputs, and prepares metadata for compaction. + * Uses the default block-cache path for PQ retraining vector extraction. + * + * @see #OnDiskGraphIndexCompactor(List, List, List, VectorSimilarityFunction, ForkJoinPool, Function) */ public OnDiskGraphIndexCompactor( List sources, @@ -82,6 +93,27 @@ public OnDiskGraphIndexCompactor( List remappers, VectorSimilarityFunction similarityFunction, ForkJoinPool executor) { + this(sources, liveNodes, remappers, similarityFunction, executor, null); + } + + /** + * Constructs a new OnDiskGraphIndexCompactor to merge multiple graph indexes, + * with an optional reader-supplier factory for bulk PQ retraining I/O. + * + * @param readerSupplierFactory when non-null, called once per source segment + * during PQ retraining to obtain a + * {@link ReaderSupplier} backed by a pre-downloaded + * or locally-buffered copy of the index file; avoids + * per-node block-cache round-trips (issue #599 + * Option B); pass {@code null} for the default path + */ + public OnDiskGraphIndexCompactor( + List sources, + List liveNodes, + List remappers, + VectorSimilarityFunction similarityFunction, + ForkJoinPool executor, + Function readerSupplierFactory) { checkBeforeCompact(sources, liveNodes, remappers); int threads = Runtime.getRuntime().availableProcessors(); @@ -112,6 +144,7 @@ public OnDiskGraphIndexCompactor( maxOrdinal = max(mapper.maxOrdinal(), maxOrdinal); } this.similarityFunction = similarityFunction; + this.readerSupplierFactory = readerSupplierFactory; } /** @@ -971,7 +1004,9 @@ private List computeLayerInfoFromSources() { * indexes. This ensures the PQ is optimized for the combined dataset. 
*/ private ProductQuantization resolvePQFromSources(VectorSimilarityFunction similarityFunction) { - PQRetrainer retrainer = new PQRetrainer(sources, liveNodes, dimension); + // Pass the reader-supplier factory so PQRetrainer can open bulk sequential readers + // instead of per-node block-cache reads when remote storage is involved (issue #599). + PQRetrainer retrainer = new PQRetrainer(sources, liveNodes, dimension, readerSupplierFactory); return retrainer.retrain(similarityFunction); } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/PQRetrainer.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/PQRetrainer.java index 86a93c584..3538a7fe9 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/PQRetrainer.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/PQRetrainer.java @@ -16,6 +16,7 @@ package io.github.jbellis.jvector.graph.disk; +import io.github.jbellis.jvector.disk.ReaderSupplier; import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues; import io.github.jbellis.jvector.graph.disk.feature.FeatureId; import io.github.jbellis.jvector.graph.disk.feature.FusedPQ; @@ -42,6 +43,7 @@ import java.util.concurrent.Future; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; /** * Handles Product Quantization retraining for graph index compaction. @@ -74,11 +76,44 @@ public class PQRetrainer { private final List numLiveNodesPerSource; private final int dimension; private final int numTotalNodes; + /** + * Optional factory that, given a source {@link OnDiskGraphIndex}, returns a + * {@link ReaderSupplier} backed by a pre-downloaded or locally-buffered copy of + * the index file. 
When non-null, {@link #extractVectorsSequential} uses it to + * open Views via {@link OnDiskGraphIndex#getView(ReaderSupplier)} instead of the + * default {@link OnDiskGraphIndex#getView()}, avoiding per-node block-cache reads + * over remote storage (issue #599 Option B). + * + *
<p>
The factory may throw {@link RuntimeException} (including + * {@link java.io.UncheckedIOException}) on failure; such exceptions surface as + * the existing {@code RuntimeException("PQ retraining vector extraction failed")} + * wrapper. + */ + private final Function readerSupplierFactory; + /** + * Constructs a {@code PQRetrainer} without a custom reader-supplier factory. + * Vector extraction uses the default {@link OnDiskGraphIndex#getView()}, which + * routes reads through the block cache. + */ public PQRetrainer(List sources, List liveNodes, int dimension) { + this(sources, liveNodes, dimension, null); + } + + /** + * Constructs a {@code PQRetrainer} with an optional reader-supplier factory. + * + * @param readerSupplierFactory when non-null, called once per source segment to + * obtain a {@link ReaderSupplier} for bulk sequential + * reads; the returned supplier is closed after all + * vectors for that source are extracted + */ + public PQRetrainer(List sources, List liveNodes, int dimension, + Function readerSupplierFactory) { this.sources = sources; this.liveNodes = liveNodes; this.dimension = dimension; + this.readerSupplierFactory = readerSupplierFactory; this.numLiveNodesPerSource = new ArrayList<>(sources.size()); int total = 0; @@ -270,8 +305,30 @@ private List> extractVectorsSequential(List samples) { futures.add(pool.submit(() -> { // One View — and one RandomAccessReader — per task. Never // shared across threads, so concurrent extraction is safe. - OnDiskGraphIndex.View view = - (OnDiskGraphIndex.View) sources.get(source).getView(); + // + // When a readerSupplierFactory is present, use it to open a View + // backed by a pre-downloaded or locally-buffered copy of the index + // file instead of the default block-cache reader. This eliminates + // per-node remote-storage round-trips (issue #599 Option B). 
+ final OnDiskGraphIndex odg = sources.get(source); + final ReaderSupplier supplierForSource; + final OnDiskGraphIndex.View view; + if (readerSupplierFactory != null) { + supplierForSource = readerSupplierFactory.apply(odg); + try { + view = odg.getView(supplierForSource); + } catch (IOException e) { + try { + supplierForSource.close(); + } catch (IOException suppressed) { + e.addSuppressed(suppressed); + } + throw new java.io.UncheckedIOException(e); + } + } else { + supplierForSource = null; + view = (OnDiskGraphIndex.View) odg.getView(); + } try { VectorFloat scratch = vectorTypeSupport.createFloatVector(dimension); for (SampleRef ref : group) { @@ -290,6 +347,14 @@ private List> extractVectorsSequential(List samples) { log.warn("Failed to close source {} view during PQ retraining", source, ioe); } + if (supplierForSource != null) { + try { + supplierForSource.close(); + } catch (IOException ioe) { + log.warn("Failed to close bulk reader supplier for source {} during PQ retraining", + source, ioe); + } + } } })); } diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestPQRetrainerCustomReader.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestPQRetrainerCustomReader.java new file mode 100644 index 000000000..5b6b2ee7e --- /dev/null +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestPQRetrainerCustomReader.java @@ -0,0 +1,271 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.jbellis.jvector.graph.disk; + +import com.carrotsearch.randomizedtesting.RandomizedTest; +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import io.github.jbellis.jvector.disk.RandomAccessReader; +import io.github.jbellis.jvector.disk.ReaderSupplier; +import io.github.jbellis.jvector.disk.ReaderSupplierFactory; +import io.github.jbellis.jvector.graph.GraphIndexBuilder; +import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues; +import io.github.jbellis.jvector.graph.RandomAccessVectorValues; +import io.github.jbellis.jvector.graph.disk.feature.Feature; +import io.github.jbellis.jvector.graph.disk.feature.FeatureId; +import io.github.jbellis.jvector.graph.disk.feature.FusedPQ; +import io.github.jbellis.jvector.graph.disk.feature.InlineVectors; +import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider; +import io.github.jbellis.jvector.quantization.PQVectors; +import io.github.jbellis.jvector.quantization.ProductQuantization; +import io.github.jbellis.jvector.util.FixedBitSet; +import io.github.jbellis.jvector.vector.VectorSimilarityFunction; +import io.github.jbellis.jvector.vector.VectorizationProvider; +import io.github.jbellis.jvector.vector.types.VectorFloat; +import io.github.jbellis.jvector.vector.types.VectorTypeSupport; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.EnumMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.IntFunction; + +import static io.github.jbellis.jvector.TestUtil.createRandomVectors; +import static io.github.jbellis.jvector.quantization.KMeansPlusPlusClusterer.UNWEIGHTED; +import 
static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +/** + * Verifies the custom-reader-supplier API introduced in issue #599 (Option B): + * + *

<ol>
+ *   <li>{@link OnDiskGraphIndex#getView(ReaderSupplier)} opens a View backed by the + * caller-supplied reader and reads vectors correctly.</li>
+ *   <li>When a non-null factory is passed to + * {@link PQRetrainer#PQRetrainer(List, List, int, java.util.function.Function)}, + * the factory is invoked once per source segment during + * {@link PQRetrainer#retrain}, and the returned PQ codebook is valid.</li>
+ * </ol>
+ */ +@ThreadLeakScope(ThreadLeakScope.Scope.NONE) +public class TestPQRetrainerCustomReader extends RandomizedTest { + + private static final VectorTypeSupport vectorTypeSupport = + VectorizationProvider.getInstance().getVectorTypeSupport(); + + private final ForkJoinPool executor = ForkJoinPool.commonPool(); + + private Path testDirectory; + private static final int DIMENSION = 32; + private static final int VECTORS_PER_SOURCE = 256; + private static final VectorSimilarityFunction SIMILARITY = VectorSimilarityFunction.COSINE; + + @Before + public void setUp() throws IOException { + testDirectory = Files.createTempDirectory("TestPQRetrainerCustomReader"); + } + + @After + public void tearDown() throws IOException { + // Delete temp files created during the test. + if (testDirectory != null) { + try (var stream = Files.walk(testDirectory)) { + stream.sorted(java.util.Comparator.reverseOrder()) + .map(java.nio.file.Path::toFile) + .forEach(java.io.File::delete); + } + } + } + + // ------------------------------------------------------------------------- + // Helpers + // ------------------------------------------------------------------------- + + /** + * Builds a small FusedPQ-enabled {@link OnDiskGraphIndex} and writes it to + * {@code outputPath}. 
+ * + * @return the written path (same as {@code outputPath}) + */ + private Path buildFusedPQSource(List> vecs, Path outputPath) throws IOException { + RandomAccessVectorValues ravv = new ListRandomAccessVectorValues(vecs, DIMENSION); + ProductQuantization pq = ProductQuantization.compute(ravv, 8, 256, true, UNWEIGHTED, executor, executor); + PQVectors pqv = (PQVectors) pq.encodeAll(ravv, executor); + var bsp = BuildScoreProvider.pqBuildScoreProvider(SIMILARITY, pqv); + var builder = new GraphIndexBuilder(bsp, DIMENSION, 16, 100, 1.2f, 1.2f, false, true, executor, executor); + var graph = builder.getGraph(); + + Map> writeSuppliers = new EnumMap<>(FeatureId.class); + writeSuppliers.put(FeatureId.INLINE_VECTORS, ordinal -> new InlineVectors.State(ravv.getVector(ordinal))); + + var identityMapper = new OrdinalMapper.IdentityMapper(ravv.size() - 1); + var writerBuilder = new OnDiskGraphIndexWriter.Builder(graph, outputPath); + writerBuilder.withMapper(identityMapper); + writerBuilder.with(new InlineVectors(DIMENSION)); + writerBuilder.with(new FusedPQ(graph.maxDegree(), pq)); + var writer = writerBuilder.build(); + + for (int node = 0; node < ravv.size(); node++) { + var stateMap = new EnumMap(FeatureId.class); + stateMap.put(FeatureId.INLINE_VECTORS, writeSuppliers.get(FeatureId.INLINE_VECTORS).apply(node)); + writer.writeInline(node, stateMap); + builder.addGraphNode(node, ravv.getVector(node)); + } + builder.cleanup(); + writeSuppliers.put(FeatureId.FUSED_PQ, ordinal -> new FusedPQ.State(graph.getView(), pqv, ordinal)); + writer.write(writeSuppliers); + return outputPath; + } + + // ------------------------------------------------------------------------- + // Tests + // ------------------------------------------------------------------------- + + /** + * Verifies that {@link OnDiskGraphIndex#getView(ReaderSupplier)} opens a view + * backed by the caller-supplied reader and returns the same vectors as the + * default {@link OnDiskGraphIndex#getView()}. 
+ */ + @Test + public void getViewWithCustomSupplierReadsVectorsCorrectly() throws IOException { + List> vecs = createRandomVectors(VECTORS_PER_SOURCE, DIMENSION); + Path p = buildFusedPQSource(vecs, testDirectory.resolve("source.idx")); + + try (ReaderSupplier defaultSupplier = ReaderSupplierFactory.open(p); + ReaderSupplier customSupplier = ReaderSupplierFactory.open(p)) { + OnDiskGraphIndex odg = OnDiskGraphIndex.load(defaultSupplier); + + // Read node 0 via the default path. + VectorFloat expected = vectorTypeSupport.createFloatVector(DIMENSION); + try (OnDiskGraphIndex.View defaultView = odg.getView()) { + defaultView.getVectorInto(0, expected, 0); + } + + // Read the same node via the custom-supplier path. + VectorFloat actual = vectorTypeSupport.createFloatVector(DIMENSION); + try (OnDiskGraphIndex.View customView = odg.getView(customSupplier)) { + customView.getVectorInto(0, actual, 0); + } + + // Vectors must be byte-identical. + for (int d = 0; d < DIMENSION; d++) { + assertEquals("dimension " + d, expected.get(d), actual.get(d), 0f); + } + } + } + + /** + * Verifies that when a non-null factory is passed to {@link PQRetrainer}, the + * factory is invoked exactly once per source segment during + * {@link PQRetrainer#retrain}, and the resulting PQ codebook is non-null and + * has the expected structure. 
+ */ + @Test + public void pqRetrainerInvokesCustomFactoryOncePerSource() throws IOException { + final int numSources = 3; + + List sources = new ArrayList<>(); + List liveSets = new ArrayList<>(); + List openSuppliers = new ArrayList<>(); + + for (int i = 0; i < numSources; i++) { + List> vecs = createRandomVectors(VECTORS_PER_SOURCE, DIMENSION); + Path p = buildFusedPQSource(vecs, testDirectory.resolve("src_" + i + ".idx")); + ReaderSupplier rs = ReaderSupplierFactory.open(p); + openSuppliers.add(rs); + OnDiskGraphIndex odg = OnDiskGraphIndex.load(rs); + sources.add(odg); + + FixedBitSet live = new FixedBitSet(VECTORS_PER_SOURCE); + live.set(0, VECTORS_PER_SOURCE); // all nodes live + liveSets.add(live); + } + + AtomicInteger factoryCalls = new AtomicInteger(0); + + // Build a tracking factory that counts invocations and wraps the file reader. + java.util.function.Function trackingFactory = odg -> { + int idx = sources.indexOf(odg); + assertTrue("factory called with unknown source", idx >= 0); + factoryCalls.incrementAndGet(); + // Open a fresh reader for the source file (same path as the source). + try { + return ReaderSupplierFactory.open(testDirectory.resolve("src_" + idx + ".idx")); + } catch (IOException e) { + throw new java.io.UncheckedIOException(e); + } + }; + + PQRetrainer retrainer = new PQRetrainer(sources, liveSets, DIMENSION, trackingFactory); + ProductQuantization pq = retrainer.retrain(SIMILARITY); + + // The factory must have been called once per source. + assertEquals("factory must be called once per source", numSources, factoryCalls.get()); + + // The resulting PQ codebook must be valid. + assertNotNull("PQ codebook must not be null", pq); + assertTrue("PQ must have at least one subspace", pq.getSubspaceCount() > 0); + assertEquals("PQ cluster count must be 256", 256, pq.getClusterCount()); + + // Clean up open suppliers. 
+ for (ReaderSupplier rs : openSuppliers) { + rs.close(); + } + } + + /** + * Verifies that the default (3-arg) constructor delegates with a null factory, + * i.e. the original code path is unaffected. + */ + @Test + public void defaultConstructorUsesNullFactory() throws IOException { + List> vecs0 = createRandomVectors(VECTORS_PER_SOURCE, DIMENSION); + List> vecs1 = createRandomVectors(VECTORS_PER_SOURCE, DIMENSION); + Path p0 = buildFusedPQSource(vecs0, testDirectory.resolve("default0.idx")); + Path p1 = buildFusedPQSource(vecs1, testDirectory.resolve("default1.idx")); + + List sources = new ArrayList<>(); + List liveSets = new ArrayList<>(); + List openSuppliers = new ArrayList<>(); + + for (Path p : new Path[]{p0, p1}) { + ReaderSupplier rs = ReaderSupplierFactory.open(p); + openSuppliers.add(rs); + sources.add(OnDiskGraphIndex.load(rs)); + FixedBitSet live = new FixedBitSet(VECTORS_PER_SOURCE); + live.set(0, VECTORS_PER_SOURCE); + liveSets.add(live); + } + + // 3-arg constructor — null factory, uses default getView() path. + PQRetrainer retrainer = new PQRetrainer(sources, liveSets, DIMENSION); + ProductQuantization pq = retrainer.retrain(SIMILARITY); + assertNotNull("default-path PQ codebook must not be null", pq); + + for (ReaderSupplier rs : openSuppliers) { + rs.close(); + } + } +}