forked from AlgRUC/JittorGeometric
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgat_example.py
More file actions
154 lines (130 loc) · 4.73 KB
/
gat_example.py
File metadata and controls
154 lines (130 loc) · 4.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
'''
Author: lusz
Date: 2024-07-01 15:42:52
Description:
'''
import os.path as osp
import argparse
import jittor as jt
from jittor import nn
import sys,os
root = osp.dirname(osp.dirname(osp.abspath(__file__)))
sys.path.append(root)
from jittor_geometric.datasets import Planetoid
import jittor_geometric.transforms as T
from jittor_geometric.nn import GATConv
import time
from jittor import Var
from jittor_geometric.utils import add_remaining_self_loops
from jittor_geometric.utils.num_nodes import maybe_num_nodes
from jittor_geometric.data import CSC,CSR
from jittor_geometric.ops import cootocsr,cootocsc
# Setup configuration
jt.flags.use_cuda = 1         # run on GPU
jt.flags.lazy_execution = 0   # eager execution (easier to debug/time)
# jt.misc.set_global_seed(42)
jt.cudnn.set_max_workspace_ratio(0.0)  # cap cuDNN workspace memory usage
# Edge normalization for GCN: symmetric normalization D^{-1/2} A D^{-1/2}
def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
             add_self_loops=True, dtype=None):
    """Symmetrically normalize edge weights: w_ij <- d_i^{-1/2} * w_ij * d_j^{-1/2}.

    Args:
        edge_index: (2, E) Var of COO edge endpoints (row 0 = src, row 1 = dst).
        edge_weight: optional (E,) Var of edge weights; defaults to all ones.
        num_nodes: node count; inferred via maybe_num_nodes when None.
        improved: use fill value 2.0 instead of 1.0 for added self-loops.
        add_self_loops: add self-loops for nodes that lack them.
        dtype: unused; kept for interface compatibility.

    Returns:
        (edge_index, normalized_edge_weight) when edge_index is a Var;
        implicitly returns None otherwise (original behavior preserved).
    """
    fill_value = 2. if improved else 1.
    if isinstance(edge_index, Var):
        num_nodes = maybe_num_nodes(edge_index, num_nodes)
        if edge_weight is None:
            edge_weight = jt.ones((edge_index.size(1), ))
        if add_self_loops:
            edge_index, tmp_edge_weight = add_remaining_self_loops(
                edge_index, edge_weight, fill_value, num_nodes)
            assert tmp_edge_weight is not None
            edge_weight = tmp_edge_weight
        row, col = edge_index[0], edge_index[1]
        # Accumulate weighted in-degree per destination node.
        shape = list(edge_weight.shape)
        shape[0] = num_nodes
        deg = jt.zeros(shape)
        deg = jt.scatter(deg, 0, col, src=edge_weight, reduce='add')
        deg_inv_sqrt = deg.pow(-0.5)
        # BUG FIX: masked_fill is out-of-place (torch-style semantics); the
        # original discarded its return value, leaving inf for zero-degree
        # nodes. Rebind the result so isolated nodes get weight 0.
        deg_inv_sqrt = deg_inv_sqrt.masked_fill(deg_inv_sqrt == float('inf'), 0)
        return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
# Parse arguments
parser = argparse.ArgumentParser()
parser.add_argument('--use_gdc', action='store_true',
                    help='Use GDC preprocessing.')
parser.add_argument('--dataset', help='graph dataset')
args = parser.parse_args()
# Keep the CLI string under its own name instead of reusing `dataset`,
# which was immediately rebound to the Planetoid object (name shadowing).
dataset_name = args.dataset
# Load dataset with row-normalized node features.
path = osp.join(osp.dirname(osp.realpath(__file__)), '../data')
dataset = Planetoid(path, dataset_name, transform=T.NormalizeFeatures())
data = dataset[0]
# Apply GDC preprocessing if requested:
# Graph Diffusion Convolution with PPR diffusion and top-k sparsification.
if args.use_gdc:
    gdc = T.GDC(self_loop_weight=1, normalization_in='sym',
                normalization_out='col',
                diffusion_kwargs=dict(method='ppr', alpha=0.05),
                sparsification_kwargs=dict(method='topk', k=128,
                                           dim=0), exact=True)
    data = gdc(data)
# Prepare data and edge normalization
v_num = data.x.shape[0]           # number of nodes
e_num = data.edge_index.shape[1]  # number of edges (before self-loops added)
edge_index, edge_weight = data.edge_index, data.edge_attr
# Positional args: improved=False, add_self_loops=True.
edge_index, edge_weight = gcn_norm(
    edge_index, edge_weight, v_num,
    False, True)
# Convert to sparse matrix formats; the graph structure needs no gradients.
with jt.no_grad():
    data.csc = cootocsc(edge_index, edge_weight, v_num)
    data.csr = cootocsr(edge_index, edge_weight, v_num)
# GAT model with two attention layers
class Net(nn.Module):
    """Two-layer GAT for node classification.

    Reads the module-level `data` (x, csc) and `dataset` metadata;
    `execute` returns per-node log-probabilities over the classes.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = GATConv(dataset.num_features, 128, e_num, cached=True,
                             normalize=not args.use_gdc)
        self.conv2 = GATConv(128, dataset.num_classes, e_num, cached=True,
                             normalize=not args.use_gdc)
        # BUG FIX: the functional nn.dropout defaults to is_train=False, so
        # the original call was a no-op. The nn.Dropout module tracks the
        # model's train/eval state and applies dropout only during training.
        self.dropout = nn.Dropout()

    def execute(self):
        x, csc = data.x, data.csc
        x = nn.relu(self.conv1(x, csc))
        x = self.dropout(x)
        # NOTE(review): applying relu to the last layer's logits before
        # log_softmax zeroes negative logits; the standard example feeds
        # conv2's output directly to log_softmax — confirm intent.
        x = nn.relu(self.conv2(x, csc))
        return nn.log_softmax(x, dim=1)
# Initialize model and optimizer.
# Two parameter groups allow per-layer weight decay (both 1e-4 here).
model, data = Net(), data
optimizer = nn.Adam([
    dict(params=model.conv1.parameters(), weight_decay=1e-4),
    dict(params=model.conv2.parameters(), weight_decay=1e-4)
], lr=5e-3)
# Training function
def train():
    """Run one optimization step on the training nodes using NLL loss."""
    model.train()
    out = model()
    mask = data.train_mask
    loss = nn.nll_loss(out[mask], data.y[mask])
    optimizer.step(loss)
# Evaluation function
def test():
    """Return accuracies on the train, val and test masks (in that order)."""
    model.eval()
    out = model()
    accs = []
    for _, mask in data('train_mask', 'val_mask', 'test_mask'):
        predicted, _ = jt.argmax(out[mask], dim=1)
        correct = predicted.equal(data.y[mask]).sum().item()
        accs.append(correct / mask.sum().item())
    return accs
# Training loop
train()  # warm-up step before the reported epochs (kept from original)
best_val_acc = test_acc = 0
for epoch in range(1, 201):
    train()
    train_acc, val_acc, tmp_test_acc = test()
    # Report the test accuracy achieved at the best-validation epoch.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        test_acc = tmp_test_acc
    print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, '
          f'Val: {best_val_acc:.4f}, Test: {test_acc:.4f}')
jt.sync_all()
jt.gc()