diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java
index 4270225c9..757fcac2f 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java
@@ -455,6 +455,23 @@ public View getView() {
}
}
+ /**
+ * Opens a {@link View} backed by a caller-supplied {@link ReaderSupplier} instead of the
+ * index's own internal reader. This allows callers that have already downloaded or buffered
+ * the index file locally (e.g. for bulk vector extraction during PQ retraining) to avoid
+ * routing reads through the block cache.
+ *
+ * <p>The caller is responsible for closing the returned View (which closes the underlying
+ * reader) and for separately closing the {@code supplier} when it is no longer needed.
+ *
+ * @param supplier the reader supplier to use; {@code supplier.get()} is called exactly once
+ * @return a new View backed by the provided reader
+ * @throws IOException if {@code supplier.get()} throws
+ */
+ public View getView(ReaderSupplier supplier) throws IOException {
+ return new View(supplier.get());
+ }
+
@Override
public double getAverageDegree(int level) {
var view = this.getView();
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexCompactor.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexCompactor.java
index 07b96cc85..66cde90bb 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexCompactor.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexCompactor.java
@@ -25,7 +25,9 @@
import java.nio.file.StandardOpenOption;
import java.util.*;
import java.util.concurrent.*;
+import java.util.function.Function;
import java.util.stream.IntStream;
+import io.github.jbellis.jvector.disk.ReaderSupplier;
import io.github.jbellis.jvector.graph.*;
import io.github.jbellis.jvector.graph.disk.feature.Feature;
import io.github.jbellis.jvector.graph.disk.feature.FeatureId;
@@ -71,10 +73,19 @@ public final class OnDiskGraphIndexCompactor implements Accountable {
private final ForkJoinPool executor;
private final int taskWindowSize;
private final VectorSimilarityFunction similarityFunction;
+ /**
+ * Optional factory for bulk sequential readers used during PQ retraining.
+ * {@link #resolvePQFromSources} forwards it, null or not, to {@link PQRetrainer};
+ * a non-null factory lets training-vector extraction avoid per-node block-cache
+ * round-trips (issue #599 Option B). {@code null} uses the default block-cache path.
+ */
+ private final Function<OnDiskGraphIndex, ReaderSupplier> readerSupplierFactory;
/**
* Constructs a new OnDiskGraphIndexCompactor to merge multiple graph indexes.
- * Initializes thread pool, validates inputs, and prepares metadata for compaction.
+ * Uses the default block-cache path for PQ retraining vector extraction.
+ *
+ * @see #OnDiskGraphIndexCompactor(List, List, List, VectorSimilarityFunction, ForkJoinPool, Function)
*/
public OnDiskGraphIndexCompactor(
List sources,
@@ -82,6 +93,27 @@ public OnDiskGraphIndexCompactor(
List remappers,
VectorSimilarityFunction similarityFunction,
ForkJoinPool executor) {
+ this(sources, liveNodes, remappers, similarityFunction, executor, null);
+ }
+
+ /**
+ * Constructs a new OnDiskGraphIndexCompactor to merge multiple graph indexes,
+ * with an optional reader-supplier factory for bulk PQ retraining I/O.
+ *
+ * @param readerSupplierFactory when non-null, called once per source segment
+ * during PQ retraining to obtain a
+ * {@link ReaderSupplier} backed by a pre-downloaded
+ * or locally-buffered copy of the index file; avoids
+ * per-node block-cache round-trips (issue #599
+ * Option B); pass {@code null} for the default path
+ */
+ public OnDiskGraphIndexCompactor(
+ List sources,
+ List liveNodes,
+ List remappers,
+ VectorSimilarityFunction similarityFunction,
+ ForkJoinPool executor,
+ Function readerSupplierFactory) {
checkBeforeCompact(sources, liveNodes, remappers);
int threads = Runtime.getRuntime().availableProcessors();
@@ -112,6 +144,7 @@ public OnDiskGraphIndexCompactor(
maxOrdinal = max(mapper.maxOrdinal(), maxOrdinal);
}
this.similarityFunction = similarityFunction;
+ this.readerSupplierFactory = readerSupplierFactory;
}
/**
@@ -971,7 +1004,9 @@ private List computeLayerInfoFromSources() {
* indexes. This ensures the PQ is optimized for the combined dataset.
*/
private ProductQuantization resolvePQFromSources(VectorSimilarityFunction similarityFunction) {
- PQRetrainer retrainer = new PQRetrainer(sources, liveNodes, dimension);
+ // Forward the reader-supplier factory (possibly null) to PQRetrainer: a null factory keeps
+ // the default getView() path, a non-null one supplies bulk readers per source (issue #599).
+ PQRetrainer retrainer = new PQRetrainer(sources, liveNodes, dimension, readerSupplierFactory);
return retrainer.retrain(similarityFunction);
}
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/PQRetrainer.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/PQRetrainer.java
index 86a93c584..3538a7fe9 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/PQRetrainer.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/PQRetrainer.java
@@ -16,6 +16,7 @@
package io.github.jbellis.jvector.graph.disk;
+import io.github.jbellis.jvector.disk.ReaderSupplier;
import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues;
import io.github.jbellis.jvector.graph.disk.feature.FeatureId;
import io.github.jbellis.jvector.graph.disk.feature.FusedPQ;
@@ -42,6 +43,7 @@
import java.util.concurrent.Future;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Function;
/**
* Handles Product Quantization retraining for graph index compaction.
@@ -74,11 +76,44 @@ public class PQRetrainer {
private final List numLiveNodesPerSource;
private final int dimension;
private final int numTotalNodes;
+ /**
+ * Optional factory that, given a source {@link OnDiskGraphIndex}, returns a
+ * {@link ReaderSupplier} backed by a pre-downloaded or locally-buffered copy of
+ * the index file. When non-null, {@link #extractVectorsSequential} uses it to
+ * open Views via {@link OnDiskGraphIndex#getView(ReaderSupplier)} instead of the
+ * default {@link OnDiskGraphIndex#getView()}, avoiding per-node block-cache reads
+ * over remote storage (issue #599 Option B).
+ *
+ * <p>The factory may throw {@link RuntimeException} (including
+ * {@link java.io.UncheckedIOException}) on failure; such exceptions surface as
+ * the existing {@code RuntimeException("PQ retraining vector extraction failed")}
+ * wrapper.
+ */
+ private final Function<OnDiskGraphIndex, ReaderSupplier> readerSupplierFactory;
+ /**
+ * Constructs a {@code PQRetrainer} without a custom reader-supplier factory by
+ * delegating to the four-argument constructor with a {@code null} factory, so vector
+ * extraction opens its Views through the default {@link OnDiskGraphIndex#getView()}.
+ */
public PQRetrainer(List sources, List liveNodes, int dimension) {
+ this(sources, liveNodes, dimension, null);
+ }
+
+ /**
+ * Constructs a {@code PQRetrainer} with an optional reader-supplier factory.
+ *
+ * @param readerSupplierFactory when non-null, called once per source segment to
+ * obtain a {@link ReaderSupplier} for bulk sequential
+ * reads; the returned supplier is closed after all
+ * vectors for that source are extracted
+ */
+ public PQRetrainer(List sources, List liveNodes, int dimension,
+ Function readerSupplierFactory) {
this.sources = sources;
this.liveNodes = liveNodes;
this.dimension = dimension;
+ this.readerSupplierFactory = readerSupplierFactory;
this.numLiveNodesPerSource = new ArrayList<>(sources.size());
int total = 0;
@@ -270,8 +305,30 @@ private List> extractVectorsSequential(List samples) {
futures.add(pool.submit(() -> {
// One View — and one RandomAccessReader — per task. Never
// shared across threads, so concurrent extraction is safe.
- OnDiskGraphIndex.View view =
- (OnDiskGraphIndex.View) sources.get(source).getView();
+ //
+ // When a readerSupplierFactory is present, use it to open a View
+ // backed by a pre-downloaded or locally-buffered copy of the index
+ // file instead of the default block-cache reader. This eliminates
+ // per-node remote-storage round-trips (issue #599 Option B).
+ final OnDiskGraphIndex odg = sources.get(source);
+ final ReaderSupplier supplierForSource;
+ final OnDiskGraphIndex.View view;
+ if (readerSupplierFactory != null) {
+ supplierForSource = readerSupplierFactory.apply(odg);
+ try {
+ view = odg.getView(supplierForSource);
+ } catch (IOException e) {
+ try {
+ supplierForSource.close();
+ } catch (IOException suppressed) {
+ e.addSuppressed(suppressed);
+ }
+ throw new java.io.UncheckedIOException(e);
+ }
+ } else {
+ supplierForSource = null;
+ view = (OnDiskGraphIndex.View) odg.getView();
+ }
try {
VectorFloat> scratch = vectorTypeSupport.createFloatVector(dimension);
for (SampleRef ref : group) {
@@ -290,6 +347,14 @@ private List> extractVectorsSequential(List samples) {
log.warn("Failed to close source {} view during PQ retraining",
source, ioe);
}
+ if (supplierForSource != null) {
+ try {
+ supplierForSource.close();
+ } catch (IOException ioe) {
+ log.warn("Failed to close bulk reader supplier for source {} during PQ retraining",
+ source, ioe);
+ }
+ }
}
}));
}
diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestPQRetrainerCustomReader.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestPQRetrainerCustomReader.java
new file mode 100644
index 000000000..5b6b2ee7e
--- /dev/null
+++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestPQRetrainerCustomReader.java
@@ -0,0 +1,271 @@
+/*
+ * Copyright DataStax, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.github.jbellis.jvector.graph.disk;
+
+import com.carrotsearch.randomizedtesting.RandomizedTest;
+import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;
+import io.github.jbellis.jvector.disk.RandomAccessReader;
+import io.github.jbellis.jvector.disk.ReaderSupplier;
+import io.github.jbellis.jvector.disk.ReaderSupplierFactory;
+import io.github.jbellis.jvector.graph.GraphIndexBuilder;
+import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues;
+import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
+import io.github.jbellis.jvector.graph.disk.feature.Feature;
+import io.github.jbellis.jvector.graph.disk.feature.FeatureId;
+import io.github.jbellis.jvector.graph.disk.feature.FusedPQ;
+import io.github.jbellis.jvector.graph.disk.feature.InlineVectors;
+import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider;
+import io.github.jbellis.jvector.quantization.PQVectors;
+import io.github.jbellis.jvector.quantization.ProductQuantization;
+import io.github.jbellis.jvector.util.FixedBitSet;
+import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
+import io.github.jbellis.jvector.vector.VectorizationProvider;
+import io.github.jbellis.jvector.vector.types.VectorFloat;
+import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.EnumMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ForkJoinPool;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.IntFunction;
+
+import static io.github.jbellis.jvector.TestUtil.createRandomVectors;
+import static io.github.jbellis.jvector.quantization.KMeansPlusPlusClusterer.UNWEIGHTED;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Verifies the custom-reader-supplier API introduced in issue #599 (Option B):
+ *
+ * <ul>
+ *   <li>{@link OnDiskGraphIndex#getView(ReaderSupplier)} opens a View backed by the
+ *       caller-supplied reader and reads vectors correctly.</li>
+ *   <li>When a non-null factory is passed to
+ *       {@link PQRetrainer#PQRetrainer(List, List, int, java.util.function.Function)},
+ *       the factory is invoked once per source segment during
+ *       {@link PQRetrainer#retrain}, and the returned PQ codebook is valid.</li>
+ * </ul>
+ */
+@ThreadLeakScope(ThreadLeakScope.Scope.NONE)
+public class TestPQRetrainerCustomReader extends RandomizedTest {
+
+    private static final VectorTypeSupport vectorTypeSupport =
+            VectorizationProvider.getInstance().getVectorTypeSupport();
+
+    private final ForkJoinPool executor = ForkJoinPool.commonPool();
+
+    private Path testDirectory;
+    private static final int DIMENSION = 32;
+    private static final int VECTORS_PER_SOURCE = 256;
+    private static final VectorSimilarityFunction SIMILARITY = VectorSimilarityFunction.COSINE;
+
+    @Before
+    public void setUp() throws IOException {
+        testDirectory = Files.createTempDirectory("TestPQRetrainerCustomReader");
+    }
+
+    @After
+    public void tearDown() throws IOException {
+        // Delete temp files created during the test (children before parents).
+        if (testDirectory != null) {
+            try (var stream = Files.walk(testDirectory)) {
+                stream.sorted(java.util.Comparator.reverseOrder())
+                        .map(java.nio.file.Path::toFile)
+                        .forEach(java.io.File::delete);
+            }
+        }
+    }
+
+    // -------------------------------------------------------------------------
+    // Helpers
+    // -------------------------------------------------------------------------
+
+    /**
+     * Builds a small FusedPQ-enabled {@link OnDiskGraphIndex} and writes it to
+     * {@code outputPath}.
+     *
+     * @return the written path (same as {@code outputPath})
+     */
+    private Path buildFusedPQSource(List<VectorFloat<?>> vecs, Path outputPath) throws IOException {
+        RandomAccessVectorValues ravv = new ListRandomAccessVectorValues(vecs, DIMENSION);
+        ProductQuantization pq = ProductQuantization.compute(ravv, 8, 256, true, UNWEIGHTED, executor, executor);
+        PQVectors pqv = (PQVectors) pq.encodeAll(ravv, executor);
+        var bsp = BuildScoreProvider.pqBuildScoreProvider(SIMILARITY, pqv);
+        var builder = new GraphIndexBuilder(bsp, DIMENSION, 16, 100, 1.2f, 1.2f, false, true, executor, executor);
+        var graph = builder.getGraph();
+
+        Map<FeatureId, IntFunction<Feature.State>> writeSuppliers = new EnumMap<>(FeatureId.class);
+        writeSuppliers.put(FeatureId.INLINE_VECTORS, ordinal -> new InlineVectors.State(ravv.getVector(ordinal)));
+
+        var identityMapper = new OrdinalMapper.IdentityMapper(ravv.size() - 1);
+        var writerBuilder = new OnDiskGraphIndexWriter.Builder(graph, outputPath);
+        writerBuilder.withMapper(identityMapper);
+        writerBuilder.with(new InlineVectors(DIMENSION));
+        writerBuilder.with(new FusedPQ(graph.maxDegree(), pq));
+        var writer = writerBuilder.build();
+
+        for (int node = 0; node < ravv.size(); node++) {
+            var stateMap = new EnumMap<FeatureId, Feature.State>(FeatureId.class);
+            stateMap.put(FeatureId.INLINE_VECTORS, writeSuppliers.get(FeatureId.INLINE_VECTORS).apply(node));
+            writer.writeInline(node, stateMap);
+            builder.addGraphNode(node, ravv.getVector(node));
+        }
+        builder.cleanup();
+        writeSuppliers.put(FeatureId.FUSED_PQ, ordinal -> new FusedPQ.State(graph.getView(), pqv, ordinal));
+        writer.write(writeSuppliers);
+        return outputPath;
+    }
+
+    // -------------------------------------------------------------------------
+    // Tests
+    // -------------------------------------------------------------------------
+
+    /**
+     * Verifies that {@link OnDiskGraphIndex#getView(ReaderSupplier)} opens a view
+     * backed by the caller-supplied reader and returns the same vector for node 0
+     * as the default {@link OnDiskGraphIndex#getView()}.
+     */
+    @Test
+    public void getViewWithCustomSupplierReadsVectorsCorrectly() throws IOException {
+        List<VectorFloat<?>> vecs = createRandomVectors(VECTORS_PER_SOURCE, DIMENSION);
+        Path p = buildFusedPQSource(vecs, testDirectory.resolve("source.idx"));
+
+        try (ReaderSupplier defaultSupplier = ReaderSupplierFactory.open(p);
+             ReaderSupplier customSupplier = ReaderSupplierFactory.open(p)) {
+            OnDiskGraphIndex odg = OnDiskGraphIndex.load(defaultSupplier);
+
+            // Read node 0 via the default path.
+            VectorFloat<?> expected = vectorTypeSupport.createFloatVector(DIMENSION);
+            try (OnDiskGraphIndex.View defaultView = odg.getView()) {
+                defaultView.getVectorInto(0, expected, 0);
+            }
+
+            // Read the same node via the custom-supplier path.
+            VectorFloat<?> actual = vectorTypeSupport.createFloatVector(DIMENSION);
+            try (OnDiskGraphIndex.View customView = odg.getView(customSupplier)) {
+                customView.getVectorInto(0, actual, 0);
+            }
+
+            // Vectors must match exactly in every dimension (zero tolerance).
+            for (int d = 0; d < DIMENSION; d++) {
+                assertEquals("dimension " + d, expected.get(d), actual.get(d), 0f);
+            }
+        }
+    }
+
+    /**
+     * Verifies that when a non-null factory is passed to {@link PQRetrainer}, the
+     * factory is invoked exactly once per source segment during
+     * {@link PQRetrainer#retrain}, and the resulting PQ codebook is non-null and
+     * has the expected structure.
+     */
+    @Test
+    public void pqRetrainerInvokesCustomFactoryOncePerSource() throws IOException {
+        final int numSources = 3;
+
+        List<OnDiskGraphIndex> sources = new ArrayList<>();
+        List<FixedBitSet> liveSets = new ArrayList<>();
+        List<ReaderSupplier> openSuppliers = new ArrayList<>();
+
+        for (int i = 0; i < numSources; i++) {
+            List<VectorFloat<?>> vecs = createRandomVectors(VECTORS_PER_SOURCE, DIMENSION);
+            Path p = buildFusedPQSource(vecs, testDirectory.resolve("src_" + i + ".idx"));
+            ReaderSupplier rs = ReaderSupplierFactory.open(p);
+            openSuppliers.add(rs);
+            OnDiskGraphIndex odg = OnDiskGraphIndex.load(rs);
+            sources.add(odg);
+
+            FixedBitSet live = new FixedBitSet(VECTORS_PER_SOURCE);
+            live.set(0, VECTORS_PER_SOURCE); // all nodes live
+            liveSets.add(live);
+        }
+
+        AtomicInteger factoryCalls = new AtomicInteger(0);
+
+        // Build a tracking factory that counts invocations and wraps the file reader.
+        java.util.function.Function<OnDiskGraphIndex, ReaderSupplier> trackingFactory = odg -> {
+            int idx = sources.indexOf(odg);
+            assertTrue("factory called with unknown source", idx >= 0);
+            factoryCalls.incrementAndGet();
+            // Open a fresh reader for the source file (same path as the source).
+            try {
+                return ReaderSupplierFactory.open(testDirectory.resolve("src_" + idx + ".idx"));
+            } catch (IOException e) {
+                throw new java.io.UncheckedIOException(e);
+            }
+        };
+
+        PQRetrainer retrainer = new PQRetrainer(sources, liveSets, DIMENSION, trackingFactory);
+        ProductQuantization pq = retrainer.retrain(SIMILARITY);
+
+        // The factory must have been called once per source.
+        assertEquals("factory must be called once per source", numSources, factoryCalls.get());
+
+        // The resulting PQ codebook must be valid.
+        assertNotNull("PQ codebook must not be null", pq);
+        assertTrue("PQ must have at least one subspace", pq.getSubspaceCount() > 0);
+        assertEquals("PQ cluster count must be 256", 256, pq.getClusterCount());
+
+        // Clean up open suppliers.
+        for (ReaderSupplier rs : openSuppliers) {
+            rs.close();
+        }
+    }
+
+    /**
+     * Verifies that the default (3-arg) constructor delegates with a null factory,
+     * i.e. the original code path is unaffected.
+     */
+    @Test
+    public void defaultConstructorUsesNullFactory() throws IOException {
+        List<VectorFloat<?>> vecs0 = createRandomVectors(VECTORS_PER_SOURCE, DIMENSION);
+        List<VectorFloat<?>> vecs1 = createRandomVectors(VECTORS_PER_SOURCE, DIMENSION);
+        Path p0 = buildFusedPQSource(vecs0, testDirectory.resolve("default0.idx"));
+        Path p1 = buildFusedPQSource(vecs1, testDirectory.resolve("default1.idx"));
+
+        List<OnDiskGraphIndex> sources = new ArrayList<>();
+        List<FixedBitSet> liveSets = new ArrayList<>();
+        List<ReaderSupplier> openSuppliers = new ArrayList<>();
+
+        for (Path p : new Path[]{p0, p1}) {
+            ReaderSupplier rs = ReaderSupplierFactory.open(p);
+            openSuppliers.add(rs);
+            sources.add(OnDiskGraphIndex.load(rs));
+            FixedBitSet live = new FixedBitSet(VECTORS_PER_SOURCE);
+            live.set(0, VECTORS_PER_SOURCE);
+            liveSets.add(live);
+        }
+
+        // 3-arg constructor — null factory, uses default getView() path.
+        PQRetrainer retrainer = new PQRetrainer(sources, liveSets, DIMENSION);
+        ProductQuantization pq = retrainer.retrain(SIMILARITY);
+        assertNotNull("default-path PQ codebook must not be null", pq);
+
+        for (ReaderSupplier rs : openSuppliers) {
+            rs.close();
+        }
+    }
+}