Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,19 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;

/**
* Handles Product Quantization retraining for graph index compaction.
Expand All @@ -48,6 +57,18 @@ public class PQRetrainer {
// avoiding the random I/O that would happen with per-node random sampling.
private static final int SAMPLE_CHUNK_SIZE = 32;

// Maximum number of source segments whose training vectors are extracted
// concurrently. Each source is read by a single thread using its own View,
// so this is the number of remote-storage reads that can be in flight at
// once. Defaults to 16; overridable via the system property below for
// deployments backed by slower or faster storage. Clamped to >= 1.
private static final int IO_THREADS = Math.max(1,
Integer.getInteger("jvector.pq.retrain.io.threads", 16));

// Emit a progress log line every this many extracted training vectors, so
// a slow-but-progressing extraction can be distinguished from a hang.
private static final int PROGRESS_LOG_INTERVAL = 10_000;

private final List<OnDiskGraphIndex> sources;
private final List<FixedBitSet> liveNodes;
private final List<Integer> numLiveNodesPerSource;
Expand Down Expand Up @@ -80,16 +101,15 @@ public ProductQuantization retrain(VectorSimilarityFunction similarityFunction)
List<SampleRef> samples = sampleBalanced(ProductQuantization.MAX_PQ_TRAINING_SET_SIZE);

// Sort by (source, node) so extractVectorsSequential reads each source's file
// in ascending order, enabling OS read-ahead instead of random page faults.
// in ascending order, enabling read-ahead instead of random page faults.
samples.sort(Comparator.comparingInt((SampleRef r) -> r.source).thenComparingInt(r -> r.node));

log.info("Collected {} training samples", samples.size());

// Extract vectors sequentially in sorted (source, node) order so disk reads are
// purely sequential and the OS read-ahead can cover them efficiently. We do this
// here rather than letting ProductQuantization.compute() drive the reads via its
// parallel stream, which would scatter page faults across a potentially very large
// file and cause I/O that scales with dataset size rather than sample count.
// Extract vectors up front so ProductQuantization.compute() itself performs no
// I/O. extractVectorsSequential reads each source ascending (read-ahead friendly)
// and reads distinct sources concurrently so the per-read latency of remote
// storage is hidden instead of serialized across thousands of samples.
List<VectorFloat<?>> trainingVectors = extractVectorsSequential(samples);
var ravv = new ListRandomAccessVectorValues(trainingVectors, dimension);

Expand Down Expand Up @@ -198,22 +218,100 @@ private List<SampleRef> sampleBalanced(int totalSamples) {
}

/**
* Reads sampled vectors in the order provided. The caller must pre-sort {@code samples}
* by (source, node) so reads within each source are ascending, letting the OS read-ahead
* cover them efficiently. Each source's view is opened once and reused for all its samples.
* Reads the sampled vectors into memory.
*
* <p>The caller pre-sorts {@code samples} by (source, node) so each source's
* sub-list is contiguous and ascending. Sources are then extracted
* <em>concurrently</em>: each source is processed by a single worker thread
* using its own {@link OnDiskGraphIndex.View} (and therefore its own
* {@code RandomAccessReader}), so concurrent reads are safe and up to
* {@link #IO_THREADS} remote-storage reads can be in flight at once. Within
* a source the reads stay ascending, preserving read-ahead friendliness.
*
* <p>This matters for graphs backed by remote storage (e.g. an object
* store): there each {@code getVectorInto} is a network round-trip with no
* OS read-ahead, so a single-threaded loop over thousands of samples
* serializes thousands of round-trips. Reading sources in parallel hides
* that latency. The returned list order is unspecified — irrelevant, since
* it only feeds {@link ProductQuantization#compute} as an unordered
* training set.
*/
private List<VectorFloat<?>> extractVectorsSequential(List<SampleRef> samples) {
OnDiskGraphIndex.View[] views = new OnDiskGraphIndex.View[sources.size()];
for (int s = 0; s < sources.size(); s++) {
views[s] = (OnDiskGraphIndex.View) sources.get(s).getView();
if (samples.isEmpty()) {
return new ArrayList<>();
}

List<VectorFloat<?>> vectors = new ArrayList<>(samples.size());
VectorFloat<?> tmp = vectorTypeSupport.createFloatVector(dimension);
// Group the pre-sorted samples by source. The caller sorted by
// (source, node) so each source's sub-list is contiguous and ascending.
Map<Integer, List<SampleRef>> bySource = new LinkedHashMap<>();
for (SampleRef ref : samples) {
views[ref.source].getVectorInto(ref.node, tmp, 0);
vectors.add(tmp.copy());
bySource.computeIfAbsent(ref.source, k -> new ArrayList<>()).add(ref);
}

int parallelism = Math.min(bySource.size(), IO_THREADS);
// Order is irrelevant for PQ codebook training, so a synchronized list
// collecting from all workers is sufficient.
List<VectorFloat<?>> vectors =
Collections.synchronizedList(new ArrayList<>(samples.size()));
AtomicInteger progress = new AtomicInteger();
int totalSamples = samples.size();
long startNanos = System.nanoTime();

ExecutorService pool = Executors.newFixedThreadPool(parallelism, r -> {
Thread t = new Thread(r, "pq-retrain-io");
t.setDaemon(true);
return t;
});
try {
List<Future<?>> futures = new ArrayList<>(bySource.size());
for (Map.Entry<Integer, List<SampleRef>> entry : bySource.entrySet()) {
final int source = entry.getKey();
final List<SampleRef> group = entry.getValue();
futures.add(pool.submit(() -> {
// One View — and one RandomAccessReader — per task. Never
// shared across threads, so concurrent extraction is safe.
OnDiskGraphIndex.View view =
(OnDiskGraphIndex.View) sources.get(source).getView();
try {
VectorFloat<?> scratch = vectorTypeSupport.createFloatVector(dimension);
for (SampleRef ref : group) {
view.getVectorInto(ref.node, scratch, 0);
vectors.add(scratch.copy());
int done = progress.incrementAndGet();
if (done % PROGRESS_LOG_INTERVAL == 0) {
log.info("PQ retraining: extracted {}/{} training vectors",
done, totalSamples);
}
}
} finally {
try {
view.close();
} catch (IOException ioe) {
log.warn("Failed to close source {} view during PQ retraining",
source, ioe);
}
}
}));
}
// Wait for completion; surface the first failure.
for (Future<?> f : futures) {
try {
f.get();
} catch (InterruptedException ie) {
Thread.currentThread().interrupt();
throw new RuntimeException("PQ retraining vector extraction interrupted", ie);
} catch (ExecutionException ee) {
throw new RuntimeException("PQ retraining vector extraction failed",
ee.getCause());
}
}
} finally {
pool.shutdownNow();
}

log.info("PQ retraining: extracted {} training vectors from {} sources in {} ms ({} threads)",
vectors.size(), bySource.size(),
(System.nanoTime() - startNanos) / 1_000_000L, parallelism);
return vectors;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,101 @@ public void testCompact() throws Exception {
searcher.close();
}

/**
* Builds a FusedPQ-enabled source graph from the given vectors and writes it
* to {@code name} under the test directory. Mirrors {@link #buildFusedPQ()}
* but is parameterised so a test can create an arbitrary number of sources.
*/
private Path buildFusedPQSource(List<VectorFloat<?>> vecs, String name) throws IOException {
RandomAccessVectorValues ravv = new ListRandomAccessVectorValues(vecs, dimension);
ProductQuantization pq = ProductQuantization.compute(ravv, 8, 256, true, UNWEIGHTED, simdExecutor, parallelExecutor);
PQVectors pqv = (PQVectors) pq.encodeAll(ravv, simdExecutor);
var bsp = BuildScoreProvider.pqBuildScoreProvider(similarityFunction, pqv);
var builder = new GraphIndexBuilder(bsp, dimension, 16, 100, 1.2f, 1.2f, false, true, simdExecutor, parallelExecutor);
var graph = builder.getGraph();

var outputPath = testDirectory.resolve(name);
Map<FeatureId, IntFunction<Feature.State>> writeSuppliers = new EnumMap<>(FeatureId.class);
writeSuppliers.put(FeatureId.INLINE_VECTORS, ordinal -> new InlineVectors.State(ravv.getVector(ordinal)));

var identityMapper = new OrdinalMapper.IdentityMapper(ravv.size() - 1);
var writerBuilder = new OnDiskGraphIndexWriter.Builder(graph, outputPath);
writerBuilder.withMapper(identityMapper);
writerBuilder.with(new InlineVectors(dimension));
writerBuilder.with(new FusedPQ(graph.maxDegree(), pq));
var writer = writerBuilder.build();

for (var node = 0; node < ravv.size(); node++) {
var stateMap = new EnumMap<FeatureId, Feature.State>(FeatureId.class);
stateMap.put(FeatureId.INLINE_VECTORS, writeSuppliers.get(FeatureId.INLINE_VECTORS).apply(node));
writer.writeInline(node, stateMap);
builder.addGraphNode(node, ravv.getVector(node));
}
builder.cleanup();

writeSuppliers.put(FeatureId.FUSED_PQ, ordinal -> new FusedPQ.State(graph.getView(), pqv, ordinal));
writer.write(writeSuppliers);
return outputPath;
}

/**
* Exercises the parallel PQ-retraining path (issue #587): with many FusedPQ
* sources, {@code PQRetrainer.extractVectorsSequential} reads each source on
* its own thread/View. This test compacts more sources than the default
* {@code testCompact} uses, then verifies that every source's inline vectors
* survive the compaction exactly at their remapped ordinals — proving the
* concurrent extraction did not throw, deadlock, drop, or mis-assign vectors.
*/
@Test
public void testCompactManySourcesParallelRetrain() throws Exception {
final int sourceCount = 8;
final int nodesPerSource = 256;

List<OnDiskGraphIndex> graphs = new ArrayList<>();
List<ReaderSupplier> rss = new ArrayList<>();
List<FixedBitSet> liveNodes = new ArrayList<>();
List<OrdinalMapper> remappers = new ArrayList<>();
// expected[globalOrdinal] -> original vector
List<VectorFloat<?>> expected = new ArrayList<>();

int globalOrdinal = 0;
for (int s = 0; s < sourceCount; s++) {
List<VectorFloat<?>> vecs = createRandomVectors(nodesPerSource, dimension);
Path path = buildFusedPQSource(vecs, "parallel_retrain_src_" + s);

ReaderSupplier rs = ReaderSupplierFactory.open(path);
rss.add(rs);
graphs.add(OnDiskGraphIndex.load(rs));

Map<Integer, Integer> map = new HashMap<>(nodesPerSource);
for (int i = 0; i < nodesPerSource; i++) {
map.put(i, globalOrdinal++);
expected.add(vecs.get(i));
}
remappers.add(new OrdinalMapper.MapMapper(map));

FixedBitSet live = new FixedBitSet(nodesPerSource);
live.set(0, nodesPerSource);
liveNodes.add(live);
}

var compactor = new OnDiskGraphIndexCompactor(graphs, liveNodes, remappers, similarityFunction, null);
Path outPath = testDirectory.resolve("parallel_retrain_out");
compactor.compact(outPath);

ReaderSupplier rsOut = ReaderSupplierFactory.open(outPath);
OnDiskGraphIndex compacted = OnDiskGraphIndex.load(rsOut);
assertEquals("compacted graph must contain every source's nodes",
sourceCount * nodesPerSource, compacted.size(0));

var view = compacted.getView();
VectorFloat<?> buf = vectorTypeSupport.createFloatVector(dimension);
for (int ord = 0; ord < expected.size(); ord++) {
view.getVectorInto(ord, buf, 0);
assertVecEquals(expected.get(ord), buf, ord);
}
}

/**
* Tests compaction with deleted nodes.
* Verifies that deleted nodes are properly excluded from the compacted graph.
Expand Down
Loading