# construct_graph.py — 83 lines (68 loc) · 2.8 KB
# (GitHub page chrome and line-number gutter removed from the extracted file.)
import os
import json
from utils import *
import torch
from transformers import CLIPTokenizer, CLIPModel, CLIPTextModel
import tqdm
import networkx as nx
import pandas as pd
import numpy as np
import sys
def main(colabfp=False):
    """Build a similarity graph over style/prompt nodes and save its adjacency matrix.

    Loads precomputed node embeddings (produced by create_embeddings.py),
    writes a node-id/category index as JSON (only if not already present),
    samples prompts from a parquet dataset, and computes the graph adjacency
    matrix via tensor operations, saving it as CSV.

    Parameters
    ----------
    colabfp : bool, optional
        When True, prefix data/output paths with the Google Colab checkout
        directory. Default False.
    """
    # WARNING: cpu would be very slow
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer, text_encoder, model = get_models(device)
    data_path = r"data/raw"
    data_file = "embeddings.pd"
    dataset_names = ["artists.txt", "mediums.txt", "movements.txt", "flavors"]
    output_dir = r"data/processed"
    output_file_name = "graph.csv"
    output_id_name = "node_ids.json"
    prompt_data = r"prompt data/data/train-00000-of-00001.parquet"
    if colabfp:
        colab_dir = r"/content/ILS-Data-Project-2023/"
        data_path = colab_dir + data_path
        output_dir = colab_dir + output_dir
    if not os.path.exists(os.path.join(output_dir, data_file)):
        print("Could not find embeddings file, please run create_embeddings.py first")
        return
    # Use a context manager so the embeddings file is closed even if
    # torch.load raises (the original handle was never closed).
    with open(os.path.join(output_dir, data_file), "rb") as embedding_file:
        node_dict = torch.load(embedding_file)
    threshold = 0.7  # similarity cutoff for drawing an edge (was 0.85)
    num_samples = 4000  # number of prompts sampled from the parquet dataset
    nodes = {}           # node id -> embedding tensor (1, dim) on `device`
    node_ids = {}        # node id -> human-readable node name
    ids_categories = {}  # dataset name -> list of node ids in that category
    curr_id = 0
    for file_cont in dataset_names:
        ids_categories[file_cont] = []
        # each entry n is (name, embedding); assign a globally unique id
        for i, n in enumerate(node_dict[file_cont]):
            nodes[curr_id] = n[1].unsqueeze(0).to(device)
            node_ids[curr_id] = n[0]
            ids_categories[file_cont].append(curr_id)
            curr_id += 1
    # Write the id/category index once; never overwrite an existing file.
    if not os.path.exists(os.path.join(output_dir, output_id_name)):
        json_dict = {
            "name" : output_id_name
        }
        json_dict["ids"] = node_ids
        json_dict["categories"] = ids_categories
        with open(os.path.join(output_dir, output_id_name), "w", encoding="utf-8") as outfile:
            json.dump(json_dict, outfile)
    data_frame = pd.read_parquet(os.path.join(data_path, prompt_data))
    print(data_frame.info())
    # NOTE(review): sample() has no random_state, so the prompt subset differs
    # between runs — confirm this is intended.
    prompts = data_frame.sample(num_samples)
    # graph = construct_graph(nodes, prompts, threshold, tokenizer, text_encoder, model, device)
    # node_list, adjacency_mat = graph_to_adjacency_matrix(graph)
    # much faster by utilizing tensor operations and directly computing the adjacency matrix:
    adjacency_mat = construct_graph_adjacency(nodes, prompts, threshold, tokenizer, text_encoder, model, device)
    print("saving ... ", end="")
    np.savetxt(os.path.join(output_dir, output_file_name), adjacency_mat, delimiter=',')
    print("done")
if __name__ == "__main__":
    # Opt into Colab filesystem paths with the --colabfp command-line flag.
    use_colab_fp = "--colabfp" in sys.argv
    if use_colab_fp:
        print("using --colabfp")
    main(colabfp=use_colab_fp)