visual.py
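"""Visualize which passage tokens a saved BertEmb query embedding attends to.

Loads a checkpoint, encodes one dev query and one passage from the MS MARCO
dataset, scores every passage token against the pooled query embedding,
prints the top-scoring tokens, and dumps the raw scores to temp.txt.
"""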
import argparse

import torch

# The wildcard imports below are assumed to also provide `BertTokenizer` and `device`.
from dataset.MarcoPassageDataset import *
from model.BertEmb import *
from util import *
def stop(word):
    """Return True for tokens to drop from the ranking (punctuation, short tokens)."""
    # return False  # uncomment to disable filtering entirely
    stopwords = ['.', ',', '"', "'", ';', ':', '?', '!']
    if word in stopwords:
        return True
    if len(word) <= 2:
        return True
    return False
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', required=False, type=int, default=5)
    parser.add_argument('--batch_size_test', required=False, type=int, default=1)
    parser.add_argument('--cls', required=False, type=int, default=-1)
    parser.add_argument('--margin_const', required=False, type=float, default=10.0)
    parser.add_argument('--max_plen', required=False, type=int, default=400)
    parser.add_argument('--max_qlen', required=False, type=int, default=20)
    parser.add_argument('--seed', required=False, type=int, default=1234)
    parser.add_argument('--sep', required=False, type=int, default=-1)
    parser.add_argument('--topic_dim', required=False, type=int, default=500)
    parser.add_argument('--total_epoch', required=False, type=int, default=10)
    args = parser.parse_args()
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    # Resolve the [CLS]/[SEP] ids from the vocabulary unless overridden on the command line.
    if args.cls == -1:
        args.cls = tokenizer.vocab['[CLS]']
    if args.sep == -1:
        args.sep = tokenizer.vocab['[SEP]']

    dataset = MarcoPassageDataset()
    model = BertEmb()
    model = model.to(device)
    state_dict = torch.load('save/1.0000.pt')['model']
    model.load_state_dict(state_dict)
    model.eval()  # inference only; disables dropout
    # Build the [CLS] query [SEP] input (segment id 0) for dev query '58'.
    q_tok = tokenizer.tokenize(dataset.queries_dev['58'])[:args.max_qlen]
    q_ids = [args.cls] + tokenizer.convert_tokens_to_ids(q_tok) + [args.sep]
    q_seg = [0 for _ in range(len(q_ids))]
    q_mask = [1 for _ in range(len(q_ids))]

    # Build the [CLS] passage [SEP] input (segment id 1) for passage '7571934'.
    p_tok = tokenizer.tokenize(dataset.passages['7571934'])[:args.max_plen]
    p_ids = [args.cls] + tokenizer.convert_tokens_to_ids(p_tok) + [args.sep]
    p_seg = [1 for _ in range(len(p_ids))]
    p_mask = [1 for _ in range(len(p_ids))]

    # Batch dimension of 1; move everything to the model's device.
    q_ids = torch.tensor([q_ids], dtype=torch.long).to(device)
    q_seg = torch.tensor([q_seg], dtype=torch.long).to(device)
    q_mask = torch.tensor([q_mask], dtype=torch.long).to(device)
    p_ids = torch.tensor([p_ids], dtype=torch.long).to(device)
    p_seg = torch.tensor([p_seg], dtype=torch.long).to(device)
    p_mask = torch.tensor([p_mask], dtype=torch.long).to(device)
    # Score every passage token: dot product between the pooled query embedding
    # and each token embedding in the passage sequence output.
    emb_q, seq_q = model(q_ids, q_seg, q_mask, out_seq=True)
    emb_p, seq_p = model(p_ids, p_seg, p_mask, out_seq=True)
    att = torch.matmul(emb_q, seq_p.permute(0, 2, 1)).view(-1).tolist()
    # att[0] is the [CLS] score, so passage token i maps to att[i + 1].
    ranks = sorted([(att[i + 1], p_tok[i]) for i in range(len(p_tok)) if not stop(p_tok[i])], reverse=True)
    print('Q:', ' '.join(q_tok))
    print('P:', ' '.join(p_tok))
    print(ranks[:10])
    # Dump tokens and per-token scores (dropping the [CLS]/[SEP] positions) for plotting.
    with open('temp.txt', 'w', encoding='utf-8') as fp:
        fp.write(' '.join(q_tok) + '\n')
        fp.write(' '.join(p_tok) + '\n')
        fp.write(' '.join(list(map(str, att[1:-1]))) + '\n')