Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 155 additions & 0 deletions examples/gatv2_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
'''
Author: AI Assistant
Date: 2024-03-10
Description: GAT_V2 example using Cora dataset
'''
import os.path as osp
import argparse

import jittor as jt
from jittor import nn
import sys, os
root = osp.dirname(osp.dirname(osp.abspath(__file__)))
sys.path.append(root)
from jittor_geometric.datasets import Planetoid
import jittor_geometric.transforms as T
from jittor_geometric.nn import GATV2Conv
import time
from jittor import Var
from jittor_geometric.utils import add_remaining_self_loops
from jittor_geometric.utils.num_nodes import maybe_num_nodes
from jittor_geometric.data import CSC, CSR
from jittor_geometric.ops import cootocsr, cootocsc

# Setup configuration
jt.flags.use_cuda = 1        # run on GPU (requires a CUDA-enabled jittor build)
jt.flags.lazy_execution = 0  # eager execution for easier debugging/profiling
# jt.misc.set_global_seed(42)
jt.cudnn.set_max_workspace_ratio(0.0)  # cap extra cuDNN workspace memory

# Edge normalization for GCN (same as GAT)
def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
             add_self_loops=True, dtype=None):
    """Symmetrically normalize edge weights: w_ij <- d_i^-1/2 * w_ij * d_j^-1/2.

    Optionally adds remaining self loops (weight 1, or 2 when ``improved``)
    before computing degrees.

    Args:
        edge_index (Var): COO edge indices, shape [2, E].
        edge_weight (Var, optional): per-edge weights; defaults to all ones.
        num_nodes (int, optional): node count; inferred from edge_index if omitted.
        improved (bool): use self-loop weight 2.0 instead of 1.0.
        add_self_loops (bool): add missing self loops before normalizing.
        dtype: unused; kept for interface compatibility.

    Returns:
        Tuple of (edge_index, normalized edge_weight).
    """
    fill_value = 2. if improved else 1.

    if isinstance(edge_index, Var):
        num_nodes = maybe_num_nodes(edge_index, num_nodes)

        if edge_weight is None:
            edge_weight = jt.ones((edge_index.size(1), ))

        if add_self_loops:
            edge_index, tmp_edge_weight = add_remaining_self_loops(
                edge_index, edge_weight, fill_value, num_nodes)
            assert tmp_edge_weight is not None
            edge_weight = tmp_edge_weight

        row, col = edge_index[0], edge_index[1]
        # Weighted in-degree per node, accumulated over destination indices.
        shape = list(edge_weight.shape)
        shape[0] = num_nodes
        deg = jt.zeros(shape)
        deg = jt.scatter(deg, 0, col, src=edge_weight, reduce='add')
        deg_inv_sqrt = deg.pow(-0.5)
        # BUG FIX: masked_fill is not in-place; the original discarded its
        # result, so isolated nodes (deg == 0 -> inf) kept inf weights.
        deg_inv_sqrt = deg_inv_sqrt.masked_fill(deg_inv_sqrt == float('inf'), 0)
        return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

# Parse arguments
parser = argparse.ArgumentParser()
parser.add_argument('--use_gdc', action='store_true',
                    help='Use GDC preprocessing.')
parser.add_argument('--dataset', default='cora', help='graph dataset')
args = parser.parse_args()
dataset = args.dataset

# Load dataset (stored under ../data; node features are row-normalized)
path = osp.join(osp.dirname(osp.realpath(__file__)), '../data')
dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
data = dataset[0]

# Apply GDC (graph diffusion) preprocessing if requested
if args.use_gdc:
    gdc = T.GDC(self_loop_weight=1, normalization_in='sym',
                normalization_out='col',
                diffusion_kwargs=dict(method='ppr', alpha=0.05),
                sparsification_kwargs=dict(method='topk', k=128,
                                           dim=0), exact=True)
    data = gdc(data)

# Prepare data and edge normalization (same as GAT)
v_num = data.x.shape[0]           # number of vertices
e_num = data.edge_index.shape[1]  # number of edges (before self loops)
edge_index, edge_weight = data.edge_index, data.edge_attr

# Positional args map to: num_nodes=v_num, improved=False, add_self_loops=True
edge_index, edge_weight = gcn_norm(
    edge_index, edge_weight, v_num,
    False, True)

# Convert to sparse matrix formats used by the conv ops; the graph structure
# is static, so no autograd is needed here.
with jt.no_grad():
    data.csc = cootocsc(edge_index, edge_weight, v_num)
    data.csr = cootocsr(edge_index, edge_weight, v_num)

# GAT_V2 model with two attention layers (same structure as GAT)
class Net(nn.Module):
    """Two-layer GATv2 network for node classification on Planetoid graphs."""

    def __init__(self):
        super(Net, self).__init__()
        # e_num is the edge count before self loops; cached=True lets the conv
        # reuse its normalized structure across forward passes.
        self.conv1 = GATV2Conv(dataset.num_features, 128, e_num, cached=True,
                               normalize=not args.use_gdc)
        self.conv2 = GATV2Conv(128, dataset.num_classes, e_num, cached=True,
                               normalize=not args.use_gdc)

    def execute(self):
        # Reads the module-level `data` (features + CSC adjacency) directly
        # rather than taking them as arguments.
        x, csc = data.x, data.csc
        x = nn.relu(self.conv1(x, csc))
        # NOTE(review): confirm nn.dropout actually drops units here — jittor's
        # dropout has a training flag that this call leaves at its default.
        x = nn.dropout(x)
        # NOTE(review): relu on the final layer zeroes all negative class
        # scores before log_softmax — unusual; confirm this is intended.
        x = nn.relu(self.conv2(x, csc))
        return nn.log_softmax(x, dim=1)


# Initialize model and optimizer
model, data = Net(), data
# Separate parameter groups so each conv layer carries its own weight decay.
optimizer = nn.Adam([
    dict(params=model.conv1.parameters(), weight_decay=1e-4),
    dict(params=model.conv2.parameters(), weight_decay=1e-4)
], lr=5e-3)


# Training function
def train():
    """Run one optimization step on the training split."""
    model.train()
    out = model()
    mask = data.train_mask
    loss = nn.nll_loss(out[mask], data.y[mask])
    optimizer.step(loss)

# Evaluation function
def test():
    """Return accuracies on the train, val and test masks, in that order."""
    model.eval()
    out = model()
    results = []
    for _, mask in data('train_mask', 'val_mask', 'test_mask'):
        prediction, _ = jt.argmax(out[mask], dim=1)
        correct = prediction.equal(data.y[mask]).sum().item()
        results.append(correct / mask.sum().item())
    return results



# Training loop
# NOTE(review): this initial train() call performs one extra optimization step
# before epoch 1 — presumably a compilation/timing warm-up; confirm.
train()
best_val_acc = test_acc = 0
for epoch in range(1, 201):
    train()
    train_acc, val_acc, tmp_test_acc = test()
    # Model selection: report the test accuracy observed at the epoch with
    # the best validation accuracy so far.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        test_acc = tmp_test_acc
    log = 'Epoch: {:03d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'
    print(log.format(epoch, train_acc, best_val_acc, test_acc))
jt.sync_all()  # wait for all queued jittor ops to finish
jt.gc()        # release cached memory
4 changes: 3 additions & 1 deletion jittor_geometric/nn/conv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .optbasis_conv import OptBasisConv
from .clustergcn_conv import ClusterGCNConv
from .sage_conv import SAGEConv
from .gatv2_conv import GATV2Conv

__all__ = [
'MessagePassing',
Expand All @@ -38,7 +39,8 @@
'TransformerConv',
'OptBasisConv',
'ClusterGCNConv',
'SAGEConv'
'SAGEConv',
'GATV2Conv'
]

classes = __all__
101 changes: 101 additions & 0 deletions jittor_geometric/nn/conv/gatv2_conv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
'''
Description: GAT_V2 Convolutional Layer
Author: AI Assistant
Date: 2024-03-10
'''
from typing import Optional, Tuple
from jittor_geometric.typing import Adj, OptVar

import jittor as jt
from jittor import Var
from jittor_geometric.nn.conv import MessagePassingNts
from jittor_geometric.utils import add_remaining_self_loops
from jittor_geometric.utils.num_nodes import maybe_num_nodes

from ..inits import glorot, zeros
from jittor_geometric.data import CSC, CSR
from jittor_geometric.ops import ScatterToEdge, EdgeSoftmax, aggregateWithWeight, ScatterToVertex


class GATV2Conv(MessagePassingNts):
    r"""The graph attention operator from the `"Graph Attention Networks v2"
    (How Attentive are Graph Attention Networks?)
    <https://arxiv.org/abs/2105.14491>`_ paper.

    The key difference from GAT is that GAT_V2 applies the non-linear
    activation *before* multiplying with the attention vector:
    ``e_ij = a^T * LeakyReLU(W [h_i || h_j])``, whereas GAT computes
    ``e_ij = LeakyReLU(a^T * W [h_i || h_j])``.
    """

    # Caches for the normalized graph structure (intended for cached=True).
    _cached_edge_index: Optional[Tuple[Var, Var]]
    _cached_csc: Optional[CSC]

    def __init__(self, in_channels: int, out_channels: int, e_num: int,
                 improved: bool = False, cached: bool = False,
                 add_self_loops: bool = True, normalize: bool = True,
                 bias: bool = True, **kwargs):
        """
        Args:
            in_channels: size of each input node feature.
            out_channels: size of each output node feature.
            e_num: number of edges (kept for interface parity with GATConv).
            improved, cached, add_self_loops, normalize: stored as flags;
                edge normalization itself is performed by the caller in the
                example scripts of this codebase.
            bias: accepted for interface compatibility.
                NOTE(review): no bias parameter is created in this layer —
                confirm whether this flag should have an effect.
        """
        kwargs.setdefault('aggr', 'add')
        super(GATV2Conv, self).__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.improved = improved
        self.cached = cached
        self.add_self_loops = add_self_loops
        self.normalize = normalize

        self._cached_edge_index = None
        self._cached_adj_t = None

        # weight: feature transform W; edge_weight: attention vector ``a`` of
        # shape [2*out_channels, 1], applied to [src || dst] edge features.
        self.weight = jt.random((in_channels, out_channels))
        self.edge_weight = jt.random((2 * out_channels, 1))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize both parameters (Glorot) and clear structure caches."""
        glorot(self.weight)
        glorot(self.edge_weight)
        self._cached_adj_t = None
        self._cached_csc = None

    def execute(self, x: Var, csc: CSC) -> Var:
        """Transform node features, then run attention-weighted propagation."""
        out = self.vertex_forward(x)
        out = self.propagate(x=out, csc=csc)
        return out

    def propagate(self, x, csc):
        # Gather endpoint features per edge, score/weight them, then
        # aggregate the weighted messages back onto vertices.
        e_msg = self.scatter_to_edge(x, csc)
        out = self.edge_forward(e_msg, csc)
        out = self.scatter_to_vertex(out, csc)
        return out

    def scatter_to_edge(self, x, csc) -> Var:
        # Per edge, concatenate the two endpoint feature vectors:
        # columns [0, out_channels) hold the "src" side, the rest the "dst" side.
        out1 = ScatterToEdge(x, csc, "src")
        out2 = ScatterToEdge(x, csc, "dst")
        out = jt.contrib.concat([out1, out2], dim=1)
        return out

    def edge_forward(self, x, csc) -> Var:
        # GAT_V2: Apply non-linear activation first, then multiply with attention vectors
        # According to paper: e_ij = a^T * LeakyReLU(W * [h_i || h_j])
        # GAT: e_ij = LeakyReLU(a^T * W * [h_i || h_j])
        m = jt.nn.leaky_relu(x, scale=0.2)  # Activation first (GAT_V2)
        out = m @ self.edge_weight  # Then multiply with attention vector (GAT_V2)
        # Normalize raw scores to attention coefficients via the EdgeSoftmax op
        # (presumably softmax over each vertex's incident edges — see the op).
        a = EdgeSoftmax(out, csc)
        half_dim = int(jt.size(x, 1) / 2)
        # The message carried by each edge is the "src" half of the
        # concatenated features, scaled by its attention coefficient.
        e_msg = x[:, 0:half_dim]
        return e_msg * a

    def scatter_to_vertex(self, edge, csc) -> Var:
        # Aggregate ('add', per the aggr default) the weighted edge messages
        # back onto vertices via the ScatterToVertex op.
        out = ScatterToVertex(edge, csc, 'src')
        return out

    def vertex_forward(self, x: Var) -> Var:
        # Keep same as GAT for fair comparison.
        # NOTE(review): the reference GATv2 applies no activation after the
        # linear transform here — confirm the extra relu is intended.
        out = x @ self.weight
        out = jt.nn.relu(out)
        return out

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)
5 changes: 5 additions & 0 deletions run_gat.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#!/bin/bash
# Run the GAT example on the Cora dataset inside the `jittor` conda env.
set -euo pipefail

# Activate the conda environment portably instead of hardcoding a developer's
# home-directory miniconda path.
if command -v conda >/dev/null 2>&1; then
    source "$(conda info --base)/etc/profile.d/conda.sh"
    conda activate jittor
fi

# Run from the directory containing this script (the repository root) instead
# of a machine-specific absolute path.
cd "$(dirname "$0")"

python examples/gat_example.py --dataset cora 2>&1 | tail -30
5 changes: 5 additions & 0 deletions run_gatv2.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#!/bin/bash
# Run the GATv2 example on the Cora dataset inside the `jittor` conda env.
set -euo pipefail

# Activate the conda environment portably instead of hardcoding a developer's
# home-directory miniconda path (the original even disagreed with run_gat.sh).
if command -v conda >/dev/null 2>&1; then
    source "$(conda info --base)/etc/profile.d/conda.sh"
    conda activate jittor
fi

# Run from the directory containing this script (the repository root) instead
# of a machine-specific absolute path.
cd "$(dirname "$0")"

python examples/gatv2_example.py --dataset cora