-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataset.py
More file actions
118 lines (93 loc) · 3.63 KB
/
dataset.py
File metadata and controls
118 lines (93 loc) · 3.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# FitLayout - Python GNN Demo
# (c) 2026 Radek Burget <burgetr@fit.vut.cz>
# Graph dataset implementations.
import os
from abc import ABC
import torch
import torch_geometric.data as data
from torch_geometric.data import Dataset, InMemoryDataset
class RemoteDataset(Dataset):
    """
    A minimalistic dataset implementation that fetches page graphs directly from a remote FitLayout remote
    server using the FitLayout client library.
    """

    def __init__(self, creator, limit=None):
        """
        Creates a new RemoteDataset.

        :param creator: FitLayout client object providing get_artifact_iris() and get_artifact_graph()
        :param limit: optional maximal number of artifacts to use (None means no limit)
        """
        super().__init__()
        self.creator = creator
        # fetch the list of available artifact IRIs once; graphs themselves are fetched lazily in get()
        all_iris = creator.get_artifact_iris()
        self.iris = all_iris if limit is None else all_iris[:limit]

    def len(self):
        """
        Returns the number of samples in the dataset.

        :return: The number of samples in the dataset
        """
        return len(self.iris)

    def get(self, idx):
        """
        Returns the sample at the given index.

        :param idx: The index of the sample to return
        :return: The sample at the given index
        """
        # each access performs a remote fetch -- no local caching is done here
        return self.creator.get_artifact_graph(self.iris[idx])
class LocalDataset(InMemoryDataset, ABC):
    """
    A custom dataset implementation that loads graphs from a local directory.
    """

    def __init__(self, dataset_path, batch_size=16):
        """
        Creates a new LocalDataset instance.

        :param dataset_path: Input dataset path. The directory should contain *.pt files representing graphs.
        :param batch_size: Number of graphs merged into the initial batch stored in ``self.data``
            (default 16, matching the previously hard-coded value)
        """
        super().__init__()
        # Decide the load target once instead of branching on every file:
        # None keeps tensors on their saved device (GPU available), 'cpu' remaps them otherwise.
        map_location = None if torch.cuda.is_available() else 'cpu'
        # load graphs; sorted() makes the sample order deterministic across filesystems
        self.graphs = []
        for fname in sorted(os.listdir(dataset_path)):
            # Load only *.pt
            if fname.endswith(".pt"):
                # NOTE(review): weights_only=False unpickles arbitrary objects --
                # only load *.pt files from trusted sources
                g = torch.load(os.path.join(dataset_path, fname),
                               weights_only=False, map_location=map_location)
                self.graphs.append(g)
        self.data = self.get_batch(batch_size)

    def __len__(self):
        """
        Returns the number of samples in the dataset.

        :return: The number of samples in the dataset
        """
        return len(self.graphs)

    def __getitem__(self, item):
        """
        Loads and returns a sample from the dataset at the given index.

        :param item: Id of requested sample
        :return: Sample at the given index
        """
        return self.graphs[item]

    def get_batch(self, batch_size):
        """
        Returns batch object, representing multiple graphs as a single disconnected graph.

        :param batch_size: Batch size
        :return: batch
        """
        # from_data_list is a classmethod -- call it on the class, not on a throwaway instance
        return data.Batch.from_data_list(self.graphs[0:batch_size])

    def get_cluster(self, n):
        """
        Returns cluster object, grouping graphs into specific number of clusters

        :param n: Number of graphs in one cluster
        :return: cluster loader
        """
        # NOTE(review): ClusterData normally expects a single Data object, not a list --
        # verify this works with self.graphs as passed here
        return data.ClusterLoader(data.ClusterData(self.graphs, n))

    def get_sampler(self, item, num_neigh, batch_size, shuffle):
        """
        Returns sampler which for each convolutional layer samples a max number of nodes from each neighborhood.

        :param item: Index of graph
        :param num_neigh: List of numbers of max neighbors for each convolution layer
        :param batch_size: Number of graphs
        :param shuffle: Randomly shuffle graphs
        :return: Sampler
        """
        return data.NeighborSampler(self.graphs[item].edge_index, sizes=num_neigh, batch_size=batch_size,
                                    shuffle=shuffle)