-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdistributed_kann.py
More file actions
executable file
·158 lines (115 loc) · 4.87 KB
/
distributed_kann.py
File metadata and controls
executable file
·158 lines (115 loc) · 4.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
#! /usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# This file is an edited version of the original, used to experiment with Faiss and
# distributed sharding.
"""
Simple distributed k-ANN implementation, based on distributed_kmeans.py.
Also relies on an abstraction for the training matrix that can be sharded over several machines.
"""
"""
FLOW
Method 1 - random sharding, test vector search.
"""
import os
import sys
import argparse
from datetime import datetime
import numpy as np
import faiss
from faiss.contrib.vecs_io import bvecs_mmap, fvecs_mmap
from DatasetAssignCustom import *
def do_test(testdata, todo):
    """Run k-ANN search benchmarks and a recall check over a bvecs dataset.

    Args:
        testdata: path to a bvecs file; it is mmapped and converted to
            float32. Exits the process if the file does not exist.
        todo: collection of test names to run; recognized values are
            "search-cpu-flat", "search-cpu-shard" and "search-gpu-shard".
            If several are given, the accuracy check scores the results of
            the last one that ran.
    """
    k_search = 10       # neighbors requested per query
    n_queries = 1_000   # queries are sampled from the database itself
    if os.path.exists(testdata):
        print("Mmapping vecs")
        # bvecs store uint8 components; convert once so every index below
        # receives float32 directly
        x = bvecs_mmap(testdata).astype("float32")
        print("Mmapping vecs done!")
    else:
        print("Dataset not real, exiting")
        sys.exit(1)
    # cap the working set, then shuffle so that the contiguous sharding
    # below amounts to random partitioning
    x = x[:10_000_000]
    np.random.shuffle(x)
    print(f"Testing over {x.shape[0]} vectors in R^{x.shape[1]}, {n_queries} queries")
    # queries are drawn from the database, so every query has an exact match
    # and recall against the query vector itself is meaningful
    queries = x[np.random.choice(x.shape[0], n_queries, replace=False)]
    start_time = datetime.now()
    D = I = None  # results of the most recent search, if any ran
    if "search-cpu-flat" in todo:
        print("Testing brute force k-ANN search")
        index = faiss.IndexFlatL2(x.shape[1])
        # x is already float32 (converted above); avoid a redundant copy of
        # up to 10M vectors
        index.add(x)
        D, I = index.search(queries, k_search)
        print("Reference search complete!")
    if "search-cpu-shard" in todo:
        # num_shards = os.cpu_count()
        num_shards = 10
        print(f"Testing distributed-kANN over {num_shards} CPU shards")
        # split the database into equal contiguous shards (remainder dropped)
        shard_size = x.shape[0] // num_shards
        data = DatasetAssignDispatch([
            DatasetAssignCustomIndex(x[shard_size * i: shard_size * (i + 1)], i)
            for i in range(num_shards)
        ], True)
        print("Starting search")
        D, I = data.search(queries, k_search)
        print("Distributed CPU search complete")
    if "search-gpu-shard" in todo:
        print("Testing k-ANN - gpu sharding")
        ngpus = faiss.get_num_gpus()
        if ngpus > 0:
            print(f"Sharding over {ngpus} gpus")
            data = DatasetAssignDispatch([
                DatasetAssignGPUCustomIndex(
                    x[x.shape[0] * i // ngpus: x.shape[0] * (i + 1) // ngpus], i)
                for i in range(ngpus)
            ], True)
            D, I = data.search(queries, k_search)
            print("GPU sharding search complete!")
        else:
            print("No gpus available, must skip")
    if I is None:
        # no recognized test ran; the original code raised a NameError here
        print("No search results to score, skipping accuracy test")
    else:
        print("Testing accuracy...")
        recall_1 = 0
        recall_10 = 0
        for i in range(queries.shape[0]):
            query = queries[i]        # shape (d,)
            top_k_vecs = x[I[i]]      # shape (k_search, d)
            # recall@1: the top result is exactly the query vector
            if (top_k_vecs[0] == query).all():
                recall_1 += 1
            # recall@k: any of the k results equals the query vector;
            # broadcasting handles any dimensionality (not just d=128)
            if np.all(top_k_vecs == query, axis=1).any():
                recall_10 += 1
        print(f"Recall@1: {recall_1 / float(queries.shape[0])}")
        print(f"Recall@10: {recall_10 / float(queries.shape[0])}")
    print(f"Total time taken: {(datetime.now() - start_time).total_seconds()} seconds")
def main():
    """Parse command-line options and dispatch the requested tests.

    Only the --test / --indata path is currently wired up; the server and
    client option groups are parsed but unused (see TODO below).
    """
    parser = argparse.ArgumentParser()

    def aa(*args, **kwargs):
        # shorthand: add an argument to the most recently created group
        # (late-binding closure over `group` is intentional)
        group.add_argument(*args, **kwargs)

    group = parser.add_argument_group('general options')
    # help lists the tests do_test() actually checks; the previous text
    # advertised a nonexistent "search-gpu-flat" test
    aa('--test', default='',
       help='perform tests (search-cpu-flat, search-gpu-shard, search-cpu-shard)')
    aa('--k', default=0, type=int, help='nb centroids')
    aa('--seed', default=1234, type=int, help='random seed')
    aa('--gpu', default=-2, type=int, help='GPU to use (-2:none, -1: all)')
    group = parser.add_argument_group('I/O options')
    aa('--indata', default='',
       help='data file to load (supported formats fvecs, bvecs, npy)')
    group = parser.add_argument_group('server options')
    aa('--server', action='store_true', default=False, help='run server')
    aa('--port', default=12345, type=int, help='server port')
    aa('--when_ready', default=None, help='store host:port to this file when ready')
    aa('--ipv4', default=False, action='store_true', help='force ipv4')
    group = parser.add_argument_group('client options')
    aa('--client', action='store_true', default=False, help='run client')
    aa('--servers', default='', help='list of server:port separated by spaces')
    args = parser.parse_args()
    if args.test:
        # --test takes a comma-separated list of test names
        do_test(args.indata, args.test.split(','))
        return
    ## TODO: make truly distributed
# Script entry point: only run when executed directly, not when imported.
if __name__ == '__main__':
    main()