-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathdataset_loader.py
More file actions
73 lines (66 loc) · 3 KB
/
dataset_loader.py
File metadata and controls
73 lines (66 loc) · 3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import argparse
import torch
from torch_geometric.datasets import Planetoid, WebKB, WikipediaNetwork, Amazon, Actor
from torch_geometric.data import Data, InMemoryDataset, download_url
from LINKX_dataset import LINKXDataset
from Hetero_dataset import HeteroDataset
from torch_geometric.utils import homophily, degree, assortativity, is_undirected, to_undirected
import torch_geometric.transforms as T
from torch_geometric.transforms import LargestConnectedComponents, ToUndirected, Compose
#from ogb.nodeproppred import PygNodePropPredDataset
# Load a graph dataset by name and print summary statistics:
# node count, edge count, feature dimension, number of classes,
# undirectedness, edge/node homophily, degree assortativity, and
# average degree.

parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, default='cora')
parser.add_argument("--root", type=str, default="/data/runlin_lei/data",
                    help="Directory where datasets are stored/downloaded.")
args = parser.parse_args()
root = args.root
name = args.dataset.lower()

# `data` is assigned early only for datasets whose graph object must be
# patched after loading; all other branches fetch it once at the end.
data = None
if name in ['cora', 'citeseer', 'pubmed']:
    dataset = Planetoid(root=root, name=name, transform=T.NormalizeFeatures())
elif name in ['computers', 'photo']:
    dataset = Amazon(root=root, name=name, transform=T.NormalizeFeatures())
elif name in ['cornell', 'texas', 'wisconsin']:
    dataset = WebKB(root=root, name=name, transform=T.NormalizeFeatures())
elif name in ['chameleon', 'squirrel']:
    # GPRGNN-like loading: use everything from the raw dataset
    # (geom_gcn_preprocess=False) except the node labels y, which come
    # from the geom-gcn preprocessed copy.
    preProcDs = WikipediaNetwork(
        root=root, name=name, geom_gcn_preprocess=False, transform=T.NormalizeFeatures())
    dataset = WikipediaNetwork(
        root=root, name=name, geom_gcn_preprocess=True, transform=T.NormalizeFeatures())
    data = dataset[0]
    # Patch in the raw edge_index. NOTE: we must keep this `data` object;
    # re-fetching dataset[0] later would return a fresh transformed copy
    # and silently discard this patch.
    data.edge_index = preProcDs[0].edge_index
elif name in ['film', 'actor']:
    dataset = Actor(root=root + '/actor/', transform=T.NormalizeFeatures())
elif name in ['penn94', 'genius', 'wiki', 'pokec', 'arxiv-year',
              'twitch-gamer', 'snap-patents', 'twitch-de', 'deezer-europe']:
    dataset = LINKXDataset(root=root, name=name)
    # arxiv-year and snap-patents are kept directed; the others are
    # symmetrized.
    if name not in ('arxiv-year', 'snap-patents'):
        dataset.data['edge_index'] = to_undirected(dataset.data['edge_index'])
elif name in ['ogbn-arxiv', 'ogbn-products', 'ogbn-papers']:
    # Imported lazily: ogb is an optional dependency and the top-of-file
    # import is commented out, so importing here avoids a NameError for
    # every other dataset choice.
    from ogb.nodeproppred import PygNodePropPredDataset
    ogb_name = 'ogbn-papers100M' if name == 'ogbn-papers' else name
    dataset = PygNodePropPredDataset(name=ogb_name, root=root)
elif name in ['roman_empire', 'amazon_ratings', 'questions', 'minesweeper', 'tolokers']:
    dataset = HeteroDataset(root=root, name=name, transform=ToUndirected())
else:
    # Fail fast instead of crashing later with NameError on `dataset`.
    raise ValueError(f"Unknown dataset: {args.dataset}")

# Fetch the graph once (unless a branch above already prepared it).
if data is None:
    data = dataset[0]
print(data)
print(data.x.shape[0])           # number of nodes
edge_index = data.edge_index
print(data.edge_index.shape[1])  # number of (directed) edge entries
print(data.x.shape[1])           # feature dimension
print(dataset.num_classes)
undirected = is_undirected(edge_index)
print(undirected)
homo = homophily(data.edge_index, data.y)                     # edge homophily
homo_node = homophily(data.edge_index, data.y, method='node')  # node homophily
print(f"{homo:.3f}", f"{homo_node:.3f}")
ass = assortativity(data.edge_index)  # degree assortativity coefficient
print(f"{ass:.3f}")
deg = data.edge_index.shape[1] / data.x.shape[0]  # average degree
print(f"{deg:.3f}")