Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 155 additions & 0 deletions examples/gatv2_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
'''
Author: Assistant
Date: 2026-02-26
Description: GATv2 Example
'''
import os.path as osp
import argparse

import jittor as jt
from jittor import nn
import sys,os
root = osp.dirname(osp.dirname(osp.abspath(__file__)))
sys.path.append(root)
from jittor_geometric.datasets import Planetoid
import jittor_geometric.transforms as T
from jittor_geometric.nn import GATV2Conv
import time
from jittor import Var
from jittor_geometric.utils import add_remaining_self_loops
from jittor_geometric.utils.num_nodes import maybe_num_nodes
from jittor_geometric.data import CSC,CSR
from jittor_geometric.ops import cootocsr,cootocsc

# Setup configuration
jt.flags.use_cuda = 1
jt.flags.lazy_execution = 0
# jt.misc.set_global_seed(42)
jt.cudnn.set_max_workspace_ratio(0.0)
# Edge normalization for GCN
def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
add_self_loops=True, dtype=None):

fill_value = 2. if improved else 1.

if isinstance(edge_index, Var):
num_nodes = maybe_num_nodes(edge_index, num_nodes)

if edge_weight is None:
edge_weight = jt.ones((edge_index.size(1), ))

if add_self_loops:
edge_index, tmp_edge_weight = add_remaining_self_loops(
edge_index, edge_weight, fill_value, num_nodes)
assert tmp_edge_weight is not None
edge_weight = tmp_edge_weight

row, col = edge_index[0], edge_index[1]
shape = list(edge_weight.shape)
shape[0] = num_nodes
deg = jt.zeros(shape)
deg = jt.scatter(deg, 0, col, src=edge_weight, reduce='add')
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt.masked_fill(deg_inv_sqrt == float('inf'), 0)
return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

# Parse arguments
parser = argparse.ArgumentParser()
parser.add_argument('--use_gdc', action='store_true',
help='Use GDC preprocessing.')
parser.add_argument('--dataset', help='graph dataset')
args = parser.parse_args()
dataset=args.dataset

# Load dataset
path = osp.join(osp.dirname(osp.realpath(__file__)), '../data')
dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
data = dataset[0]

# Apply GDC preprocessing if requested
if args.use_gdc:
gdc = T.GDC(self_loop_weight=1, normalization_in='sym',
normalization_out='col',
diffusion_kwargs=dict(method='ppr', alpha=0.05),
sparsification_kwargs=dict(method='topk', k=128,
dim=0), exact=True)
data = gdc(data)

# Prepare data and edge normalization
v_num = data.x.shape[0]
e_num = data.edge_index.shape[1]
edge_index, edge_weight=data.edge_index,data.edge_attr

edge_index, edge_weight = gcn_norm(
edge_index, edge_weight,v_num,
False, True)

# Convert to sparse matrix format
with jt.no_grad():
data.csc = cootocsc(edge_index, edge_weight, v_num)
data.csr = cootocsr(edge_index, edge_weight, v_num)

# GATv2 model with two attention layers
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = GATV2Conv(dataset.num_features,128,e_num, cached=True,
normalize=not args.use_gdc)
self.conv2 = GATV2Conv(128, dataset.num_classes,e_num, cached=True,
normalize=not args.use_gdc)

def execute(self):
x, csc =data.x , data.csc
x = self.conv1(x, csc)
x = nn.relu(x)
x = nn.dropout(x)
x = self.conv2(x,csc)
return nn.log_softmax(x, dim=1)


# Initialize model and optimizer
model, data = Net(), data
optimizer = nn.Adam([
dict(params=model.conv1.parameters(), weight_decay=5e-4),
dict(params=model.conv2.parameters(), weight_decay=5e-4)
], lr=1e-2)


# Training function
def train():
model.train()

pred = model()[data.train_mask]
label = data.y[data.train_mask]
loss = nn.nll_loss(pred, label)
optimizer.step(loss)

# Evaluation function
def test():
model.eval()
logits, accs = model(), []
# Evaluate on train, val, test sets
for _, mask in data('train_mask', 'val_mask', 'test_mask'):
y_ = data.y[mask]
logits_=logits[mask]
pred, _ = jt.argmax(logits_, dim=1)
acc = pred.equal(y_).sum().item() / mask.sum().item()
accs.append(acc)
return accs



# Training loop
train()
best_val_acc = test_acc = 0
for epoch in range(1, 201):
train()
train_acc, val_acc, tmp_test_acc = test()
# Track best validation accuracy
if val_acc > best_val_acc:
best_val_acc = val_acc
test_acc = tmp_test_acc
log = 'Epoch: {:03d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'
print(log.format(epoch, train_acc, best_val_acc, test_acc))
jt.sync_all()
jt.gc()
Loading