Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added data/cora/cora/raw/ind.cora.allx
Binary file not shown.
Binary file added data/cora/cora/raw/ind.cora.tx
Binary file not shown.
Binary file added data/cora/cora/raw/ind.cora.x
Binary file not shown.
38 changes: 24 additions & 14 deletions examples/gat_example.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
'''
Author: lusz
Date: 2024-07-01 15:42:52
Description:
Modified for multi-head attention
Description: GAT example on Cora dataset
'''
import os.path as osp
import argparse
Expand All @@ -26,6 +27,7 @@
jt.flags.lazy_execution = 0
# jt.misc.set_global_seed(42)
jt.cudnn.set_max_workspace_ratio(0.0)

# Edge normalization for GCN
def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
add_self_loops=True, dtype=None):
Expand Down Expand Up @@ -57,7 +59,7 @@ def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
parser = argparse.ArgumentParser()
parser.add_argument('--use_gdc', action='store_true',
help='Use GDC preprocessing.')
parser.add_argument('--dataset', help='graph dataset')
parser.add_argument('--dataset', default='cora', help='graph dataset')
args = parser.parse_args()
dataset=args.dataset

Expand Down Expand Up @@ -89,29 +91,37 @@ def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
data.csc = cootocsc(edge_index, edge_weight, v_num)
data.csr = cootocsr(edge_index, edge_weight, v_num)

# GAT model with two attention layers
# GAT model with multi-head attention
# Standard configuration: 8 heads for layer 1, 1 head for layer 2
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = GATConv(dataset.num_features,128,e_num, cached=True,
normalize=not args.use_gdc)
self.conv2 = GATConv(128, dataset.num_classes,e_num, cached=True,
normalize=not args.use_gdc)
# First layer: 8 attention heads, each with 8 output features
# concat=True means we concatenate the outputs of all heads
self.conv1 = GATConv(dataset.num_features, 8, e_num,
heads=8, concat=True,
negative_slope=0.2, dropout=0.6,
cached=True, normalize=not args.use_gdc)
# Second layer: 1 attention head, takes concatenated features
self.conv2 = GATConv(8 * 8, dataset.num_classes, e_num,
heads=1, concat=False,
negative_slope=0.2, dropout=0.6,
cached=True, normalize=not args.use_gdc)

def execute(self):
x, csc =data.x , data.csc
x = nn.relu(self.conv1(x, csc))
x = nn.dropout(x)
x = nn.relu(self.conv2(x,csc))
x, csc = data.x, data.csc
x = nn.dropout(x, p=0.6)
x = nn.elu(self.conv1(x, csc)) # ELU is commonly used with GAT
x = nn.dropout(x, p=0.6)
x = self.conv2(x, csc)
return nn.log_softmax(x, dim=1)


# Initialize model and optimizer
model, data = Net(), data
optimizer = nn.Adam([
dict(params=model.conv1.parameters(), weight_decay=1e-4),
dict(params=model.conv2.parameters(), weight_decay=1e-4)
], lr=5e-3)
dict(params=model.parameters(), weight_decay=5e-4)
], lr=0.005)


# Training function
Expand Down
164 changes: 164 additions & 0 deletions examples/gatv2_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
'''
Author: lusz
Date: 2024-07-01 15:42:52
Modified for GATv2 with multi-head attention
Description: GATv2 example on Cora dataset
'''
import os.path as osp
import argparse

import jittor as jt
from jittor import nn
import sys,os
root = osp.dirname(osp.dirname(osp.abspath(__file__)))
sys.path.append(root)
from jittor_geometric.datasets import Planetoid
import jittor_geometric.transforms as T
from jittor_geometric.nn import GATv2Conv
import time
from jittor import Var
from jittor_geometric.utils import add_remaining_self_loops
from jittor_geometric.utils.num_nodes import maybe_num_nodes
from jittor_geometric.data import CSC,CSR
from jittor_geometric.ops import cootocsr,cootocsc

# Setup configuration
jt.flags.use_cuda = 1
jt.flags.lazy_execution = 0
# jt.misc.set_global_seed(42)
jt.cudnn.set_max_workspace_ratio(0.0)

# Edge normalization for GCN
def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
add_self_loops=True, dtype=None):

fill_value = 2. if improved else 1.

if isinstance(edge_index, Var):
num_nodes = maybe_num_nodes(edge_index, num_nodes)

if edge_weight is None:
edge_weight = jt.ones((edge_index.size(1), ))

if add_self_loops:
edge_index, tmp_edge_weight = add_remaining_self_loops(
edge_index, edge_weight, fill_value, num_nodes)
assert tmp_edge_weight is not None
edge_weight = tmp_edge_weight

row, col = edge_index[0], edge_index[1]
shape = list(edge_weight.shape)
shape[0] = num_nodes
deg = jt.zeros(shape)
deg = jt.scatter(deg, 0, col, src=edge_weight, reduce='add')
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt.masked_fill(deg_inv_sqrt == float('inf'), 0)
return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

# Parse arguments
parser = argparse.ArgumentParser()
parser.add_argument('--use_gdc', action='store_true',
help='Use GDC preprocessing.')
parser.add_argument('--dataset', default='cora', help='graph dataset')
args = parser.parse_args()
dataset=args.dataset

# Load dataset
path = osp.join(osp.dirname(osp.realpath(__file__)), '../data')
dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
data = dataset[0]

# Apply GDC preprocessing if requested
if args.use_gdc:
gdc = T.GDC(self_loop_weight=1, normalization_in='sym',
normalization_out='col',
diffusion_kwargs=dict(method='ppr', alpha=0.05),
sparsification_kwargs=dict(method='topk', k=128,
dim=0), exact=True)
data = gdc(data)

# Prepare data and edge normalization
v_num = data.x.shape[0]
e_num = data.edge_index.shape[1]
edge_index, edge_weight=data.edge_index,data.edge_attr

edge_index, edge_weight = gcn_norm(
edge_index, edge_weight,v_num,
False, True)

# Convert to sparse matrix format
with jt.no_grad():
data.csc = cootocsc(edge_index, edge_weight, v_num)
data.csr = cootocsr(edge_index, edge_weight, v_num)

# GATv2 model with multi-head attention
# Standard configuration: 8 heads for layer 1, 1 head for layer 2
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
# First layer: 8 attention heads, each with 8 output features
# concat=True means we concatenate the outputs of all heads
self.conv1 = GATv2Conv(dataset.num_features, 8, e_num,
heads=8, concat=True,
negative_slope=0.2, dropout=0.6,
cached=True, normalize=not args.use_gdc)
# Second layer: 1 attention head, takes concatenated features
self.conv2 = GATv2Conv(8 * 8, dataset.num_classes, e_num,
heads=1, concat=False,
negative_slope=0.2, dropout=0.6,
cached=True, normalize=not args.use_gdc)

def execute(self):
x, csc = data.x, data.csc
x = nn.dropout(x, p=0.6)
x = nn.elu(self.conv1(x, csc)) # ELU is commonly used with GAT
x = nn.dropout(x, p=0.6)
x = self.conv2(x, csc)
return nn.log_softmax(x, dim=1)


# Initialize model and optimizer
model, data = Net(), data
optimizer = nn.Adam([
dict(params=model.parameters(), weight_decay=5e-4)
], lr=0.005)


# Training function
def train():
model.train()

pred = model()[data.train_mask]
label = data.y[data.train_mask]
loss = nn.nll_loss(pred, label)
optimizer.step(loss)

# Evaluation function
def test():
model.eval()
logits, accs = model(), []
# Evaluate on train, val, test sets
for _, mask in data('train_mask', 'val_mask', 'test_mask'):
y_ = data.y[mask]
logits_=logits[mask]
pred, _ = jt.argmax(logits_, dim=1)
acc = pred.equal(y_).sum().item() / mask.sum().item()
accs.append(acc)
return accs



# Training loop
train()
best_val_acc = test_acc = 0
for epoch in range(1, 201):
train()
train_acc, val_acc, tmp_test_acc = test()
# Track best validation accuracy
if val_acc > best_val_acc:
best_val_acc = val_acc
test_acc = tmp_test_acc
log = 'Epoch: {:03d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'
print(log.format(epoch, train_acc, best_val_acc, test_acc))
jt.sync_all()
jt.gc()
Loading