diff --git a/cogdl/models/nn/pair_norm.py b/cogdl/models/nn/pair_norm.py
new file mode 100644
index 00000000..7634d3df
--- /dev/null
+++ b/cogdl/models/nn/pair_norm.py
@@ -0,0 +1,358 @@
+"""
+Chuan Wen's implementation of PairNorm
+("PairNorm: Tackling Oversmoothing in GNNs", Zhao & Akoglu, ICLR 2020).
+"""
+import numpy as np
+import scipy.sparse as sp
+import torch
+import torch.nn as nn
+from torch_sparse import spmm  # requires a recent torch_sparse
+import torch.nn.functional as F
+
+from .. import BaseModel, register_model
+
+"""
+    helpers
+"""
+from torch_scatter import scatter_max, scatter_add
+
+
+def softmax(src, index, num_nodes=None):
+    """
+    Sparse softmax: normalize `src` within each group of entries that share
+    the same value in `index`.
+    """
+    num_nodes = index.max().item() + 1 if num_nodes is None else num_nodes
+    out = src - scatter_max(src, index, dim=0, dim_size=num_nodes)[0][index]
+    out = out.exp()
+    out = out / (scatter_add(out, index, dim=0, dim_size=num_nodes)[index] + 1e-16)
+    return out
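+
+
+# Illustrative example (assumed inputs, comment only): with
+#   src = torch.tensor([1., 2., 3.]) and index = torch.tensor([0, 0, 1]),
+# softmax(src, index) normalizes within each index group and returns roughly
+#   [0.269, 0.731, 1.000]  (entries 0 and 1 sum to 1; singleton entry 2 is 1).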
+
+
+def normalize_adj_row(adj):
+    """Row-normalize sparse matrix"""
+    rowsum = np.array(adj.sum(1))
+    r_inv = np.power(rowsum, -1).flatten()
+    r_inv[np.isinf(r_inv)] = 0.
+    r_mat_inv = sp.diags(r_inv)
+    mx = r_mat_inv.dot(adj)
+    return mx
+
+
+def to_torch_sparse(sparse_mx):
+    """Convert a scipy sparse matrix to a torch sparse tensor."""
+    sparse_mx = sparse_mx.tocoo().astype(np.float32)
+    indices = torch.from_numpy(
+        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
+    values = torch.from_numpy(sparse_mx.data)
+    shape = torch.Size(sparse_mx.shape)
+    return torch.sparse.FloatTensor(indices, values, shape)
+
+
+class GraphConv(nn.Module):
+    def __init__(self, in_features, out_features, bias=True):
+        super(GraphConv, self).__init__()
+        self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
+        if bias:
+            self.bias = nn.Parameter(torch.FloatTensor(1, out_features))
+        else:
+            # register None so reset_parameters() and forward() can test self.bias
+            self.register_parameter('bias', None)
+
+        self.in_features = in_features
+        self.out_features = out_features
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        stdv = 1. / np.sqrt(self.weight.size(1))
+        self.weight.data.uniform_(-stdv, stdv)
+        if self.bias is not None:
+            self.bias.data.uniform_(-stdv, stdv)
+
+    def forward(self, input, adj):
+        h = torch.mm(input, self.weight)
+        output = torch.spmm(adj, h)
+        if self.bias is not None:
+            return output + self.bias
+        return output
+
+    def __repr__(self):
+        return self.__class__.__name__ + "({}->{})".format(
+            self.in_features, self.out_features)
+
+
+class GraphAttConv(nn.Module):
+    def __init__(self, in_features, out_features, heads, dropout):
+        super(GraphAttConv, self).__init__()
+        assert out_features % heads == 0
+        out_perhead = out_features // heads
+
+        self.graph_atts = nn.ModuleList([GraphAttConvOneHead(
+            in_features, out_perhead, dropout=dropout) for _ in range(heads)])
+
+        self.in_features = in_features
+        self.out_perhead = out_perhead
+        self.heads = heads
+
+    def forward(self, input, adj):
+        output = torch.cat([att(input, adj) for att in self.graph_atts], dim=1)
+        # NOTE: the original GAT applies ELU here; this version leaves the
+        # activation to the caller.
+        return output
+
+    def __repr__(self):
+        return self.__class__.__name__ + "({}->[{}x{}])".format(
+            self.in_features, self.heads, self.out_perhead)
+
+
+class GraphAttConvOneHead(nn.Module):
+    """
+    Sparse version GAT layer, single head
+    """
+
+    def __init__(self, in_features, out_features, dropout=0.6, alpha=0.2):
+        super(GraphAttConvOneHead, self).__init__()
+        self.weight = nn.Parameter(torch.zeros(size=(in_features, out_features)))
+        self.a = nn.Parameter(torch.zeros(size=(1, 2 * out_features)))
+        # initialization follows the original implementation (ReLU gain)
+        nn.init.xavier_normal_(self.weight.data, gain=nn.init.calculate_gain('relu'))
+        nn.init.xavier_normal_(self.a.data, gain=nn.init.calculate_gain('relu'))
+
+        self.dropout = nn.Dropout(dropout)
+        self.leakyrelu = nn.LeakyReLU(alpha)
+
+    def forward(self, input, adj):
+        # edge = adj._indices()
+        edge = torch.nonzero(adj).T
+        h = torch.mm(input, self.weight)
+        # self-attention on the nodes: shared attention mechanism
+        edge_h = torch.cat((h[edge[0, :], :], h[edge[1, :], :]), dim=1).t()  # edge_h: 2*D x E
+        # group-wise softmax: normalize scores over each node's edges, grouped by edge[0]
+        alpha = self.leakyrelu(self.a.mm(edge_h).squeeze())  # E
+        n = len(input)
+        alpha = softmax(alpha, edge[0], n)
+        output = spmm(edge, self.dropout(alpha), n, n, h)  # h_prime: N x out
+        # output = spmm(edge, self.dropout(alpha), n, n, self.dropout(h))  # h_prime: N x out
+        return output
+
+
+class PairNormLayer(nn.Module):
+    def __init__(self, mode='PN', scale=1):
+        """
+        mode:
+            'None'   : no normalization
+            'PN'     : original version
+            'PN-SI'  : scale-individually version
+            'PN-SCS' : scale-and-center-simultaneously version
+
+        ('SCS'-mode is not in the paper but we found it works well in practice,
+        especially for GCN and GAT.)
+        PairNorm is typically used after each graph convolution operation.
+        """
+        assert mode in ['None', 'PN', 'PN-SI', 'PN-SCS']
+        super(PairNormLayer, self).__init__()
+        self.mode = mode
+        self.scale = scale
+
+        # Scale can be set based on the original data, and also on the current
+        # feature lengths; we leave those experiments to future work. A good
+        # pool we used for choosing scale: [0.1, 1, 10, 50, 100]
+
+    def forward(self, x):
+        if self.mode == 'None':
+            return x
+
+        col_mean = x.mean(dim=0)
+        if self.mode == 'PN':
+            x = x - col_mean
+            rownorm_mean = (1e-6 + x.pow(2).sum(dim=1).mean()).sqrt()
+            x = self.scale * x / rownorm_mean
+
+        if self.mode == 'PN-SI':
+            x = x - col_mean
+            rownorm_individual = (1e-6 + x.pow(2).sum(dim=1, keepdim=True)).sqrt()
+            x = self.scale * x / rownorm_individual
+
+        if self.mode == 'PN-SCS':
+            rownorm_individual = (1e-6 + x.pow(2).sum(dim=1, keepdim=True)).sqrt()
+            x = self.scale * x / rownorm_individual - col_mean
+
+        return x
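+
+
+# A minimal usage sketch (illustrative comment only): in 'PN' mode the layer
+# mean-centers every feature column, then rescales the rows so that their
+# root-mean-square norm equals `scale`, e.g.
+#   pn = PairNormLayer(mode='PN', scale=1)
+#   out = pn(torch.randn(4, 16))  # out.mean(dim=0) ~ 0; RMS row norm ~ 1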
+
+
+class SGC(nn.Module):
+    # for SGC we use data without normalization
+    def __init__(self, nfeat, nhid, nclass, dropout, nlayer=2, norm_mode='None', norm_scale=10, **kwargs):
+        super(SGC, self).__init__()
+        self.linear = torch.nn.Linear(nfeat, nclass)
+        self.norm = PairNormLayer(norm_mode, norm_scale)
+        self.dropout = nn.Dropout(p=dropout)
+        self.nlayer = nlayer
+
+    def forward(self, x, adj):
+        x = self.norm(x)
+        for _ in range(self.nlayer):
+            x = adj.mm(x)
+            x = self.norm(x)
+        x = self.dropout(x)
+        x = self.linear(x)
+        return x
+
+
+class GCN(nn.Module):
+    def __init__(self, nfeat, nhid, nclass, dropout,
+                 norm_mode='None', norm_scale=1, **kwargs):
+        super(GCN, self).__init__()
+        self.gc1 = GraphConv(nfeat, nhid)
+        self.gc2 = GraphConv(nhid, nclass)
+
+        self.dropout = nn.Dropout(p=dropout)
+        self.relu = nn.ReLU(True)
+        self.norm = PairNormLayer(norm_mode, norm_scale)
+
+    def forward(self, x, adj):
+        x = self.dropout(x)
+        x = self.gc1(x, adj)
+        x = self.norm(x)
+        x = self.relu(x)
+        x = self.dropout(x)
+        x = self.gc2(x, adj)
+        return x
+
+
+class GAT(nn.Module):
+    def __init__(self, nfeat, nhid, nclass, dropout, nhead,
+                 norm_mode='None', norm_scale=1, **kwargs):
+        super(GAT, self).__init__()
+        alpha_droprate = dropout
+        self.gac1 = GraphAttConv(nfeat, nhid, nhead, alpha_droprate)
+        self.gac2 = GraphAttConv(nhid, nclass, 1, alpha_droprate)
+
+        self.dropout = nn.Dropout(p=dropout)
+        self.relu = nn.ELU(True)  # despite the name, ELU as in the original GAT
+        self.norm = PairNormLayer(norm_mode, norm_scale)
+
+    def forward(self, x, adj):
+        x = self.dropout(x)  # input dropout, as in the original GAT
+        x = self.gac1(x, adj)
+        x = self.norm(x)
+        x = self.relu(x)
+        x = self.dropout(x)
+        x = self.gac2(x, adj)
+        return x
+
+
+class DeepGCN(nn.Module):
+    def __init__(self, nfeat, nhid, nclass, dropout, nlayer=2, residual=0,
+                 norm_mode='None', norm_scale=1, **kwargs):
+        super(DeepGCN, self).__init__()
+        assert nlayer >= 1
+        self.hidden_layers = nn.ModuleList([
+            GraphConv(nfeat if i == 0 else nhid, nhid)
+            for i in range(nlayer - 1)
+        ])
+        self.out_layer = GraphConv(nfeat if nlayer == 1 else nhid, nclass)
+
+        self.dropout = nn.Dropout(p=dropout)
+        self.dropout_rate = dropout
+        self.relu = nn.ReLU(True)
+        self.norm = PairNormLayer(norm_mode, norm_scale)
+        self.skip = residual
+
+    def forward(self, x, adj):
+        x_old = 0
+        for i, layer in enumerate(self.hidden_layers):
+            x = self.dropout(x)
+            x = layer(x, adj)
+            x = self.norm(x)
+            x = self.relu(x)
+            # add a skip connection every `skip` layers (x_old is 0 at the start)
+            if self.skip > 0 and i % self.skip == 0:
+                x = x + x_old
+                x_old = x
+
+        x = self.dropout(x)
+        x = self.out_layer(x, adj)
+        return x
+
+
+class DeepGAT(nn.Module):
+    def __init__(self, nfeat, nhid, nclass, dropout, nlayer=2, residual=0, nhead=1,
+                 norm_mode='None', norm_scale=1, **kwargs):
+        super(DeepGAT, self).__init__()
+        assert nlayer >= 1
+        alpha_droprate = dropout
+        self.hidden_layers = nn.ModuleList([
+            GraphAttConv(nfeat if i == 0 else nhid, nhid, nhead, alpha_droprate)
+            for i in range(nlayer - 1)
+        ])
+        self.out_layer = GraphAttConv(nfeat if nlayer == 1 else nhid, nclass, 1, alpha_droprate)
+
+        self.dropout = nn.Dropout(p=dropout)
+        self.relu = nn.ELU(True)
+        self.norm = PairNormLayer(norm_mode, norm_scale)
+        self.skip = residual
+
+    def forward(self, x, adj):
+        x_old = 0
+        for i, layer in enumerate(self.hidden_layers):
+            x = self.dropout(x)
+            x = layer(x, adj)
+            x = self.norm(x)
+            x = self.relu(x)
+            if self.skip > 0 and i % self.skip == 0:
+                x = x + x_old
+                x_old = x
+
+        x = self.dropout(x)
+        x = self.out_layer(x, adj)
+        return x
+
+
+BASE_MODELS = {'SGC': SGC, 'GCN': GCN, 'GAT': GAT, 'DeepGCN': DeepGCN, 'DeepGAT': DeepGAT}
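+# Illustrative lookup (comment only): BASE_MODELS maps the --base_model flag to
+# a backbone class, e.g. BASE_MODELS['GCN'](nfeat, nhid, nclass, dropout=0.6,
+# norm_mode='PN', norm_scale=1) builds a two-layer GCN with PairNorm applied
+# after the first convolution.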
+
+
+@register_model('chuan-pairnorm')
+class PairNorm(BaseModel):
+    @staticmethod
+    def add_args(parser):
+        """
+        missing_rate, lr, weight_decay and epochs are omitted here because they
+        are already added by NodeClassification() and get_parser(); no_fea_norm
+        is omitted because CogDL normalizes features by default.
+        """
+        parser.add_argument('--base_model', type=str, default='GCN',
+                            choices=['SGC', 'GCN', 'GAT', 'DeepGCN', 'DeepGAT'],
+                            help='Backbone: {SGC, GCN, GAT, DeepGCN, DeepGAT}')
+        parser.add_argument('--hid', type=int,
+                            default=64, help='Number of hidden units.')
+        parser.add_argument('--nhead', type=int, default=1,
+                            help='Number of attention heads.')
+        parser.add_argument('--dropout', type=float,
+                            default=0.6, help='Dropout rate.')
+        # for PairNorm
+        parser.add_argument('--nlayer', type=int, default=2,
+                            help='Number of layers; only used by the Deep models.')
+        parser.add_argument('--residual', type=int,
+                            default=0, help='Residual connection')
+        parser.add_argument('--norm_mode', type=str, default='None',
+                            help='Mode for PairNorm, {None, PN, PN-SI, PN-SCS}')
+        parser.add_argument('--norm_scale', type=float,
+                            default=1.0, help='Row-normalization scale')
+
+    @classmethod
+    def build_model_from_args(cls, args):
+        """Build a new model instance."""
+        return cls(base_model=args.base_model, nfeat=args.num_features, nhid=args.hid, nclass=args.num_classes,
+                   dropout=args.dropout, nhead=args.nhead, nlayer=args.nlayer, norm_mode=args.norm_mode,
+                   norm_scale=args.norm_scale, residual=args.residual)
+
+    def __init__(self, base_model, nfeat, nhid, nclass, dropout, nhead, nlayer, norm_mode, norm_scale, residual):
+        super(PairNorm, self).__init__()
+        self.net = BASE_MODELS[base_model](nfeat, nhid, nclass, dropout=dropout, nhead=nhead, nlayer=nlayer,
+                                           norm_mode=norm_mode, norm_scale=norm_scale, residual=residual)
+
+    def forward(self, x, adj):
+        output = self.net(x, adj)
+        return F.log_softmax(output, dim=-1)
+
+    def loss(self, data):
+        return F.nll_loss(
+            self.forward(data.x, data.adj)[data.train_mask],
+            data.y[data.train_mask],
+        )
+
+    def predict(self, data):
+        return self.forward(data.x, data.adj)
diff --git a/cogdl/tasks/node_classification.py b/cogdl/tasks/node_classification.py
index 0cd65a8a..6378a315 100644
--- a/cogdl/tasks/node_classification.py
+++ b/cogdl/tasks/node_classification.py
@@ -89,9 +89,9 @@ def __init__(
         self.device = args.device_id[0] if not args.cpu else "cpu"
         dataset = build_dataset(args) if dataset is None else dataset
         if args.missing_rate >= 0:
-            if args.model == 'sgcpn':
+            if args.model == 'sgcpn' or args.model == 'chuan-pairnorm':
                 assert args.dataset in ['cora', 'citeseer', 'pubmed']
-                dataset.data = preprocess_data_sgcpn(dataset.data, normalize_feature=True, missing_rate=0)
+                dataset.data = preprocess_data_sgcpn(dataset.data, normalize_feature=True, missing_rate=args.missing_rate)
                 adj_slice = torch.tensor(dataset.data.adj.size())
                 adj_slice[0] = 0
                 dataset.slices['adj'] = adj_slice
diff --git a/match.yml b/match.yml
index 2c97051c..7e9416eb 100644
--- a/match.yml
+++ b/match.yml
@@ -11,6 +11,7 @@ node_classification:
   - mixhop
   - graphsage
   - pairnorm
+  - chuan-pairnorm
   - unet
   - gcnii
   - chebyshev