Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import com.apple.foundationdb.Database;
import com.apple.foundationdb.Transaction;
import com.apple.foundationdb.async.AsyncUtil;
import com.apple.foundationdb.linear.AffineOperator;
import com.apple.foundationdb.linear.DoubleRealVector;
import com.apple.foundationdb.linear.HalfRealVector;
Expand All @@ -31,11 +32,11 @@
import com.apple.foundationdb.linear.StoredVecsIterator;
import com.apple.foundationdb.linear.Transformed;
import com.apple.foundationdb.subspace.Subspace;
import com.apple.foundationdb.test.ThrottledRetryingRunner;
import com.apple.foundationdb.tuple.Tuple;
import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterators;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.junit.jupiter.api.extension.AfterTestExecutionCallback;
Expand Down Expand Up @@ -68,8 +69,11 @@
import java.util.TreeSet;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.BiFunction;
import java.util.function.Consumer;
Expand Down Expand Up @@ -97,79 +101,118 @@ static void dumpQueryResults(@Nonnull final Path tempDir, @Nonnull final String
}
}

@Nonnull
static CompletableFuture<List<PrimaryKeyAndVector>> basicInsertBatch(
        @Nonnull final Transaction tr,
        @Nonnull final HNSW hnsw,
        final int batchSize,
        final long firstId,
        @Nonnull final BiFunction<Transaction, Long, PrimaryKeyAndVector> insertFunction) {
    logger.info("Inserting batch starting at " + firstId);
    // Clear per-batch read/write statistics before the inserts begin.
    ((TestOnWriteListener)hnsw.getOnWriteListener()).reset();
    ((TestOnReadListener)hnsw.getOnReadListener()).reset();

    final ImmutableList.Builder<PrimaryKeyAndVector> inserted = ImmutableList.builder();
    // Position of the next record within this batch. insertFunction is invoked lazily
    // between async inserts so that shouldContinue() observes the real elapsed time
    // after each insert instead of having every check evaluated synchronously up front.
    final AtomicInteger position = new AtomicInteger(0);
    return AsyncUtil.whileTrue(() -> {
        final int index = position.get();
        if (index >= batchSize) {
            // Batch quota exhausted.
            return AsyncUtil.READY_FALSE;
        }
        final PrimaryKeyAndVector next = insertFunction.apply(tr, firstId + index);
        if (next == null) {
            // insertFunction signals end-of-input (or a denied quota) with null.
            return AsyncUtil.READY_FALSE;
        }
        inserted.add(next);
        position.incrementAndGet();
        return hnsw.insert(tr, next.getPrimaryKey(), next.getVector())
                .thenApply(ignored -> Boolean.TRUE);
    }).<List<PrimaryKeyAndVector>>thenApply(ignored -> inserted.build());
}

/**
 * Inserts one batch of records inside a single {@link Database#runAsync} transaction,
 * delegating to the {@link Transaction}-level overload, and blocks until it completes.
 * A two-minute timeout bounds the whole batch including transaction retries so test
 * setup cannot hang forever.
 *
 * @param insertFunction produces the record for a given id, or {@code null} to stop early
 * @return the records actually inserted in this batch
 */
@Nonnull
static List<PrimaryKeyAndVector> basicInsertBatch(@Nonnull final Database db,
                                                  @Nonnull final HNSW hnsw,
                                                  final int batchSize,
                                                  final long firstId,
                                                  @Nonnull final BiFunction<Transaction, Long, PrimaryKeyAndVector> insertFunction)
        throws ExecutionException, InterruptedException, TimeoutException {
    return db.runAsync(
            tr -> {
                final long beginTs = System.nanoTime();
                return basicInsertBatch(tr, hnsw, batchSize, firstId, insertFunction)
                        .whenComplete((result, error) -> {
                            if (error != null) {
                                // Fill the placeholder with the batch size and pass the
                                // exception as SLF4J's trailing throwable so the stack
                                // trace is logged (the original passed the error where
                                // batchSize belonged).
                                logger.warn("Failed to insert batchSize={}", batchSize, error);
                            } else {
                                final long endTs = System.nanoTime();
                                final TestOnReadListener onReadListener = (TestOnReadListener)hnsw.getOnReadListener();
                                logger.info("inserted batchSize={} records={} starting at nodeId={} took elapsedTime={}ms, readCounts={}, readBytes={}",
                                        batchSize, result.size(), firstId, TimeUnit.NANOSECONDS.toMillis(endTs - beginTs),
                                        onReadListener.getNodeCountByLayer(), onReadListener.getBytesReadByLayer());
                            }
                        });
            })
            .get(2, TimeUnit.MINUTES); // set a timeout for inserting a single batch including retries so setup won't run forever
}

static List<PrimaryKeyAndVector> insertSIFTSmall(@Nonnull final Database db,
@Nonnull final HNSW hnsw) throws Exception {
final Path siftSmallPath = Paths.get(".out/extracted/siftsmall/siftsmall_base.fvecs");

final ImmutableList.Builder<PrimaryKeyAndVector> insertedDataBuilder = ImmutableList.builder();

// Load all vectors upfront so the task can index into them by position using the adaptive limit.
final List<DoubleRealVector> allVectors;
try (final var fileChannel = FileChannel.open(siftSmallPath, StandardOpenOption.READ)) {
final Iterator<DoubleRealVector> vectorIterator = new StoredVecsIterator.StoredFVecsIterator(fileChannel);

final int batchSize = 50;
int i = 0;
while (vectorIterator.hasNext()) {
final List<DoubleRealVector> batch =
Lists.newArrayList(Iterators.limit(vectorIterator, batchSize));
final long currentBatchStart = i;
final List<PrimaryKeyAndVector> insertedInBatch =
basicInsertBatch(db, hnsw, batchSize, i,
(tr, nextId) -> {
final int indexInBatch = Math.toIntExact(nextId - currentBatchStart);
if (indexInBatch >= batch.size()) {
return null;
}
final Tuple currentPrimaryKey = createPrimaryKey(nextId);
final DoubleRealVector doubleVector = batch.get(indexInBatch);
return new PrimaryKeyAndVector(currentPrimaryKey, doubleVector);
});
insertedDataBuilder.addAll(insertedInBatch);
i += insertedInBatch.size();
allVectors = Lists.newArrayList(new StoredVecsIterator.StoredFVecsIterator(fileChannel));
}

// A continuation that carries the next vector index to insert.
class SiftContinuation implements ThrottledRetryingRunner.Continuation {
final int nextIndex;

SiftContinuation(int nextIndex) {
this.nextIndex = nextIndex;
}

@Override
public boolean hasMore() {
return nextIndex < allVectors.size();
}
assertThat(i).isEqualTo(10000);
}
return insertedDataBuilder.build();

final ScheduledExecutorService executor = new ScheduledThreadPoolExecutor(1);
try (ThrottledRetryingRunner runner = ThrottledRetryingRunner.builder(db, executor)
.withMaxLimit(500)
.build()) {
runner.iterateAll((tr, quota, cont) -> {
final int startIdx = (cont instanceof SiftContinuation)
? ((SiftContinuation) cont).nextIndex : 0;
return basicInsertBatch(tr, hnsw, quota.getLimit(), startIdx,
(ignoredTr, nextId) -> {
final int idx = Math.toIntExact(nextId);
if (idx < allVectors.size() && quota.shouldContinue()) {
quota.processedCountInc();
return new PrimaryKeyAndVector(createPrimaryKey(idx), allVectors.get(idx));
}
return null;
})
.thenApply(list -> new SiftContinuation(startIdx + list.size()));
}).join();
} finally {
executor.shutdown();
}

assertThat(allVectors).hasSize(10000);
final ImmutableList.Builder<PrimaryKeyAndVector> result = ImmutableList.builder();
for (int i = 0; i < allVectors.size(); i++) {
result.add(new PrimaryKeyAndVector(createPrimaryKey(i), allVectors.get(i)));
}
return result.build();
}

static void validateSIFTSmall(@Nonnull final Database db,
Expand Down
Loading
Loading