-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathsampler.py
More file actions
89 lines (77 loc) · 3.71 KB
/
sampler.py
File metadata and controls
89 lines (77 loc) · 3.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import numpy as np
import threading
import Queue
def sample_function(user_train, user_validation, user_test, owner, usernum, itemnum, batch_size, train_queue,
                    valid_queue, test_queue, SEED):
    """Worker loop that keeps the train/valid/test queues stocked with batches.

    Runs forever; intended to be the target of a daemon thread.

    Each sample is a tuple ``(user, item_i, item_ip, owner[item_i],
    owner[item_ip])`` where ``item_i`` is a positive item for ``user`` and
    ``item_ip`` is a uniformly drawn item id that is neither ``item_i`` nor
    in the user's training ``'consume'`` collection.  A batch is
    ``batch_size`` samples transposed into a list of five tuples
    (users, positives, negatives, positive owners, negative owners).

    Args:
        user_train: indexable by user id in ``[0, usernum)``; each entry has a
            ``'consume'`` sequence of item ids.  Training positives are drawn
            uniformly from it, and it is always the exclusion set for
            negatives.
        user_validation: same layout; the first entry of ``'consume'`` is the
            positive item for validation batches.
        user_test: same layout; the first entry of ``'consume'`` is the
            positive item for test batches.
        owner: indexable by item id; looked up for both sampled items.
        usernum: exclusive upper bound for sampled user ids.
        itemnum: exclusive upper bound for sampled item ids.
        batch_size: number of samples per batch.
        train_queue: bounded queue receiving training batches.
        valid_queue: bounded queue receiving validation batches.
        test_queue: bounded queue receiving test batches.
        SEED: seed for this worker's private random generator.

    Note:
        Sampling uses rejection loops.  They do not terminate if a split has
        no user with a non-empty ``'consume'``, or if a user's training items
        together with the positive cover every id in ``[0, itemnum)``.
    """
    # Local imports keep this block self-contained; the Full exception lives
    # in ``queue`` on Python 3 and ``Queue`` on Python 2.
    import time
    try:
        from queue import Full
    except ImportError:
        from Queue import Full

    # A private generator instead of np.random.seed(): seeding the global
    # state from several worker threads makes them overwrite each other's
    # seed (and the caller's), defeating the per-worker SEED.
    rng = np.random.RandomState(SEED)
    idle_sleep_seconds = 0.001

    def sample_ui(split, is_test):
        """Draw one (user, positive, negative, owners...) sample from `split`."""
        # Only users with at least one interaction in the requested split.
        user = rng.randint(0, usernum)
        while len(split[user]['consume']) == 0:
            user = rng.randint(0, usernum)

        train_items = user_train[user]['consume']

        # find positive item
        if is_test:
            item_i = split[user]['consume'][0]
        else:
            item_i = train_items[rng.randint(0, len(train_items))]

        # find negative item
        item_ip = rng.randint(0, itemnum)
        while item_ip in train_items or item_ip == item_i:
            item_ip = rng.randint(0, itemnum)

        return user, item_i, item_ip, owner[item_i], owner[item_ip]

    targets = (
        (train_queue, user_train, False),
        (valid_queue, user_validation, True),
        (test_queue, user_test, True),
    )

    while True:
        produced = False
        for out_queue, split, is_test in targets:
            if out_queue.full():
                continue
            one_batch = [sample_ui(split, is_test) for _ in range(batch_size)]
            try:
                # list(...) so consumers get a reusable sequence on Python 3,
                # where zip() is a one-shot iterator.
                out_queue.put_nowait(list(zip(*one_batch)))
            except Full:
                # Another worker filled the queue between the full() check and
                # the put; drop this batch instead of letting the thread die.
                continue
            produced = True
        if not produced:
            # Every queue is full: yield the CPU instead of busy-spinning.
            time.sleep(idle_sleep_seconds)
class WarpSampler(object):
    """Background batch producer backed by daemon threads.

    Construction starts ``n_workers`` threads running ``sample_function``;
    all of them share three bounded queues (train / valid / test), each
    holding at most ``n_workers`` batches.  The ``next_*_batch`` methods
    block until a batch is available on the corresponding queue.
    """

    def __init__(self, user_train, user_validation, user_test, owner, usernum, itemnum, batch_size=10000, n_workers=2):
        self.train_queue, self.valid_queue, self.test_queue = (
            Queue.Queue(maxsize=n_workers) for _ in range(3)
        )

        # Everything the workers share; only the seed differs per worker.
        common_args = (
            user_train, user_validation, user_test,
            owner, usernum, itemnum, batch_size,
            self.train_queue, self.valid_queue, self.test_queue,
        )

        self.threads = []
        for _ in range(n_workers):
            worker_seed = np.random.randint(2e9)
            worker = threading.Thread(target=sample_function, args=common_args + (worker_seed,))
            self.threads.append(worker)
            # Daemon threads so the endless sampling loop never blocks interpreter exit.
            worker.daemon = True
            worker.start()

    def next_train_batch(self):
        """Block until a training batch is available and return it."""
        batch = self.train_queue.get()
        return batch

    def next_valid_batch(self):
        """Block until a validation batch is available and return it."""
        batch = self.valid_queue.get()
        return batch

    def next_test_batch(self):
        """Block until a test batch is available and return it."""
        batch = self.test_queue.get()
        return batch