Source code for graphslim.models.sgformer
import math
import os
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch_sparse import SparseTensor, matmul
from torch_geometric.utils import degree
from graphslim.utils import *
from graphslim.dataset import *
from torch import optim
from copy import deepcopy
from torch.nn import NLLLoss
[docs]
def normalize_adj(adj):
    """
    Symmetrically normalize an adjacency matrix as ``D^{-1/2} A D^{-1/2}``.

    Args:
        adj: Square adjacency matrix. Inputs for which ``is_sparse_tensor``
            returns true are handled through the sparse COO API
            (``indices()`` / ``values()``); anything else is treated as a
            dense 2-D tensor.

    Returns:
        The normalized adjacency: a sparse COO tensor for sparse input, a
        dense tensor otherwise. Entries that would be NaN or +/-inf (for
        example because a node has degree 0) are replaced by 0.

    Notes:
        The two branches define "degree" differently, as in the original
        implementation: the sparse branch counts stored entries per column
        index (edge weights are ignored), the dense branch uses row sums.
    """
    if is_sparse_tensor(adj):
        # indices()/values() are only defined on coalesced COO tensors;
        # coalesce() returns the tensor unchanged when it already is one.
        adj = adj.coalesce()
        row, col = adj.indices()
        values = adj.values()
        num_nodes = adj.size(0)

        deg = degree(col, num_nodes).float()
        d_norm_in = (1. / deg[col]).sqrt()
        d_norm_out = (1. / deg[row]).sqrt()

        normalized_values = values * d_norm_in * d_norm_out
        normalized_values = torch.nan_to_num(normalized_values, nan=0.0, posinf=0.0, neginf=0.0)
        return torch.sparse_coo_tensor(torch.stack([row, col]), normalized_values,
                                       size=(num_nodes, num_nodes))

    deg = adj.sum(dim=1).float()
    inv_sqrt_deg = (1. / deg).sqrt()
    # Scaling rows and columns by broadcasting is the same product as
    # diag(d) @ adj @ diag(d) without materializing an N x N diagonal matrix.
    adj_normalized = inv_sqrt_deg.unsqueeze(1) * adj * inv_sqrt_deg.unsqueeze(0)
    return torch.nan_to_num(adj_normalized, nan=0.0, posinf=0.0, neginf=0.0)
[docs]
class GraphConvLayer(nn.Module):
    """
    One propagation step over a normalized adjacency, with optional projection.

    Args:
        in_channels: Feature width of the incoming node representation.
        out_channels: Output width of the linear projection ``W``.
        use_weight: Apply ``W`` to the aggregated features (only consulted
            when ``use_init`` is false).
        use_init: Concatenate the aggregated features with ``x0`` before the
            projection; ``W`` then takes ``2 * in_channels`` inputs and is
            always applied.
    """

    def __init__(self, in_channels, out_channels, use_weight=True, use_init=False):
        super(GraphConvLayer, self).__init__()
        self.use_init = use_init
        self.use_weight = use_weight
        # With use_init the projection sees [aggregated, x0] side by side.
        in_channels_ = 2 * in_channels if self.use_init else in_channels
        self.W = nn.Linear(in_channels_, out_channels)

    def reset_parameters(self):
        """Re-initialize the projection; GraphConv.reset_parameters calls this on every layer."""
        self.W.reset_parameters()

    def forward(self, x, adj, x0):
        """
        Aggregate ``x`` over the normalized ``adj`` and optionally project.

        Args:
            x: Node features; left-multiplied by the normalized adjacency.
            adj: Adjacency matrix, normalized on every call by ``normalize_adj``.
            x0: Features concatenated to the aggregate when ``use_init`` is
                set; ignored otherwise.

        Returns:
            The aggregated features, projected by ``W`` when ``use_init`` or
            ``use_weight`` is set.
        """
        adj = normalize_adj(adj)
        x = adj @ x
        if self.use_init:
            x = torch.cat([x, x0], 1)
            x = self.W(x)
        elif self.use_weight:
            x = self.W(x)
        return x
[docs]
class GraphConv(nn.Module):
    """
    Input linear layer followed by ``num_layers`` ``GraphConvLayer`` blocks.

    Args:
        in_channels: Width of the raw node features.
        hidden_channels: Width used by the input layer and every conv layer.
        num_layers: Number of ``GraphConvLayer`` blocks.
        dropout: Dropout probability applied after the input layer and after
            every conv layer.
        use_bn: Apply the ``BatchNorm1d`` modules (they are created either way).
        use_residual: Add the input-layer embedding to each conv output.
        use_weight: Forwarded to every ``GraphConvLayer``.
        use_init: Forwarded to every ``GraphConvLayer``.
        use_act: Apply ReLU after each conv layer.
    """

    def __init__(self, in_channels, hidden_channels, num_layers=2, dropout=0.5, use_bn=True, use_residual=True,
                 use_weight=True, use_init=False, use_act=True):
        super(GraphConv, self).__init__()

        self.convs = nn.ModuleList()
        self.fcs = nn.ModuleList([nn.Linear(in_channels, hidden_channels)])
        self.bns = nn.ModuleList([nn.BatchNorm1d(hidden_channels)])
        for _ in range(num_layers):
            self.convs.append(
                GraphConvLayer(hidden_channels, hidden_channels, use_weight, use_init))
            self.bns.append(nn.BatchNorm1d(hidden_channels))

        self.use_bn = use_bn
        self.use_residual = use_residual
        self.use_act = use_act
        self.dropout = dropout
        self.activation = F.relu

    def reset_parameters(self):
        """Re-initialize conv layers, then norm layers, then the input linear layer."""
        for module in (*self.convs, *self.bns, *self.fcs):
            module.reset_parameters()

    def forward(self, x, edge_index):
        """
        Run the stack.

        Args:
            x: Raw node features.
            edge_index: Passed unchanged to each conv layer as its adjacency
                argument (despite the name).

        Returns:
            Output of the last conv layer, or the input-layer embedding when
            there are no conv layers.
        """
        hidden = self.fcs[0](x)
        if self.use_bn:
            hidden = self.bns[0](hidden)
        hidden = F.dropout(self.activation(hidden), p=self.dropout, training=self.training)

        # The input-layer embedding doubles as x0 for every conv layer and as
        # the residual term; it is never replaced by later layer outputs.
        first_embedding = hidden

        for depth, conv in enumerate(self.convs, start=1):
            hidden = conv(hidden, edge_index, first_embedding)
            if self.use_bn:
                hidden = self.bns[depth](hidden)
            if self.use_act:
                hidden = self.activation(hidden)
            hidden = F.dropout(hidden, p=self.dropout, training=self.training)
            if self.use_residual:
                hidden = hidden + first_embedding
        return hidden
[docs]
class TransConvLayer(nn.Module):
    """
    Linear-complexity global attention layer ("transformer with fast attention").

    Keys and values are contracted first, so the cost grows linearly with the
    number of nodes unless the explicit attention matrix is requested.

    Args:
        in_channels: Width of the incoming features.
        out_channels: Per-head width of queries, keys and values.
        num_heads: Number of attention heads.
        use_weight: Project values with ``Wv``. When false the raw source
            features are used as a single value head, which requires their
            width to equal ``out_channels``.
    """

    def __init__(self, in_channels,
                 out_channels,
                 num_heads,
                 use_weight=True):
        super().__init__()
        self.Wk = nn.Linear(in_channels, out_channels * num_heads)
        self.Wq = nn.Linear(in_channels, out_channels * num_heads)
        if use_weight:
            self.Wv = nn.Linear(in_channels, out_channels * num_heads)

        self.out_channels = out_channels
        self.num_heads = num_heads
        self.use_weight = use_weight

    def reset_parameters(self):
        """Re-initialize the key/query projections and, if present, the value projection."""
        self.Wk.reset_parameters()
        self.Wq.reset_parameters()
        if self.use_weight:
            self.Wv.reset_parameters()

    def forward(self, query_input, source_input, output_attn=False):
        """
        Compute the attention output, averaged over heads.

        Args:
            query_input: Features projected to queries, ``[N, in_channels]``.
            source_input: Features projected to keys and values. The value
                term ``N * vs`` is added row-wise to the output, so it must
                have the same number of rows as ``query_input``.
            output_attn: Also return the explicit attention matrix
                (quadratic in the number of nodes).

        Returns:
            ``[N, out_channels]`` tensor, or ``(output, attention)`` with
            ``attention`` of shape ``[N, L]`` when ``output_attn`` is true.
        """
        # feature transformation
        qs = self.Wq(query_input).reshape(-1, self.num_heads, self.out_channels)
        ks = self.Wk(source_input).reshape(-1, self.num_heads, self.out_channels)
        if self.use_weight:
            vs = self.Wv(source_input).reshape(-1, self.num_heads, self.out_channels)
        else:
            vs = source_input.reshape(-1, 1, self.out_channels)

        # normalize by the global (Frobenius) norm of the whole tensor
        qs = qs / torch.norm(qs, p=2)  # [N, H, M]
        ks = ks / torch.norm(ks, p=2)  # [L, H, M]
        N = qs.shape[0]

        # numerator: contract keys with values first, then apply the queries
        kvs = torch.einsum("lhm,lhd->hmd", ks, vs)
        attention_num = torch.einsum("nhm,hmd->nhd", qs, kvs)  # [N, H, D]
        attention_num += N * vs

        # denominator: summing keys over nodes directly keeps dtype and device
        # of ``ks`` instead of contracting with a freshly allocated ones vector
        ks_sum = ks.sum(dim=0)  # [H, M]
        attention_normalizer = torch.einsum("nhm,hm->nh", qs, ks_sum).unsqueeze(-1) + N  # [N, H, 1]

        # attentive aggregated results, averaged over heads
        attn_output = attention_num / attention_normalizer  # [N, H, D]
        final_output = attn_output.mean(dim=1)
        if not output_attn:
            return final_output

        # explicit attention, only materialized on request (e.g. visualization)
        attention = torch.einsum("nhm,lhm->nlh", qs, ks).mean(dim=-1)  # [N, L]
        normalizer = attention_normalizer.squeeze(dim=-1).mean(dim=-1, keepdim=True)  # [N, 1]
        return final_output, attention / normalizer
[docs]
class TransConv(nn.Module):
    """
    Input linear layer followed by ``num_layers`` ``TransConvLayer`` blocks.

    Args:
        in_channels: Width of the raw node features.
        hidden_channels: Width used by the input layer and every attention layer.
        num_layers: Number of ``TransConvLayer`` blocks.
        num_heads: Attention heads per layer.
        dropout: Dropout probability used in ``forward``.
        use_bn: Apply the ``LayerNorm`` modules stored in ``bns``.
        use_residual: Average each layer's output with its input.
        use_weight: Forwarded to every ``TransConvLayer``.
        use_act: Apply ReLU after each attention layer.
    """

    def __init__(self, in_channels, hidden_channels, num_layers=2, num_heads=1,
                 dropout=0.5, use_bn=True, use_residual=True, use_weight=True, use_act=True):
        super().__init__()

        self.convs = nn.ModuleList()
        self.fcs = nn.ModuleList([nn.Linear(in_channels, hidden_channels)])
        self.bns = nn.ModuleList([nn.LayerNorm(hidden_channels)])
        for _ in range(num_layers):
            self.convs.append(
                TransConvLayer(hidden_channels, hidden_channels, num_heads=num_heads, use_weight=use_weight))
            self.bns.append(nn.LayerNorm(hidden_channels))

        self.use_bn = use_bn
        self.use_residual = use_residual
        self.use_act = use_act
        self.dropout = dropout
        self.activation = F.relu

    def reset_parameters(self):
        """Re-initialize attention layers, then norm layers, then the input linear layer."""
        for module in (*self.convs, *self.bns, *self.fcs):
            module.reset_parameters()

    def forward(self, x):
        """Return the node embeddings after the input layer and all attention layers."""
        # input MLP layer
        hidden = self.fcs[0](x)
        if self.use_bn:
            hidden = self.bns[0](hidden)
        hidden = F.dropout(self.activation(hidden), p=self.dropout, training=self.training)

        # output of the previous stage, used as the residual link
        previous = hidden
        for depth, conv in enumerate(self.convs, start=1):
            # full attention aggregation with queries and sources being the same nodes
            hidden = conv(hidden, hidden)
            if self.use_residual:
                hidden = (hidden + previous) / 2.
            if self.use_bn:
                hidden = self.bns[depth](hidden)
            if self.use_act:
                hidden = self.activation(hidden)
            hidden = F.dropout(hidden, p=self.dropout, training=self.training)
            previous = hidden
        return hidden

    def get_attentions(self, x):
        """
        Return the attention matrix of every layer, stacked along a new first axis.

        Follows the same path as ``forward`` except that no dropout is applied.
        """
        hidden = self.fcs[0](x)
        if self.use_bn:
            hidden = self.bns[0](hidden)
        hidden = self.activation(hidden)

        previous = hidden
        per_layer = []
        for depth, conv in enumerate(self.convs, start=1):
            hidden, attn = conv(hidden, hidden, output_attn=True)
            per_layer.append(attn)
            if self.use_residual:
                hidden = (hidden + previous) / 2.
            if self.use_bn:
                hidden = self.bns[depth](hidden)
            if self.use_act:
                hidden = self.activation(hidden)
            previous = hidden
        return torch.stack(per_layer, dim=0)  # [layer num, N, N]
[docs]
class SGFormer(nn.Module):
    """
    Global-attention branch (``TransConv``) combined with a graph-propagation
    branch (``GraphConv``) and a linear classifier head.

    Args:
        in_channels: Width of the raw node features.
        hidden_channels: Hidden width of both branches.
        out_channels: Number of output classes.
        args: Configuration object. ``device`` and ``lr`` are required;
            ``trans_weight_decay``, ``trans_dropout`` and ``trans_num_layers``
            override the keyword arguments of the same name when present.
        mode: Accepted for interface compatibility; not used by this class.
        trans_*: Settings of the attention branch.
        gnn_*: Settings of the graph branch.
        use_graph: Combine the graph branch with the attention branch.
        graph_weight: Weight of the graph branch for ``aggregate='add'``.
        aggregate: ``'add'`` (weighted sum) or ``'cat'`` (concatenation).

    Raises:
        ValueError: If ``aggregate`` is neither ``'add'`` nor ``'cat'``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, args, mode='eval',
                 trans_num_layers=3, trans_num_heads=1, trans_dropout=0.2, trans_weight_decay=0.001,
                 trans_use_bn=False,
                 trans_use_residual=True,
                 trans_use_weight=False, trans_use_act=False,
                 gnn_num_layers=2, gnn_dropout=0.5, gnn_use_weight=False, gnn_weight_decay=5e-4,
                 gnn_use_init=False,
                 gnn_use_bn=False,
                 gnn_use_residual=False, gnn_use_act=False,
                 use_graph=True, graph_weight=0.8, aggregate='add'):
        super().__init__()
        self.device = args.device
        self.lr = args.lr

        # Values on ``args`` win over the keyword defaults. The weight-decay
        # lookup previously tested for a misspelled 'train_weight_decay'.
        self.trans_weight_decay = getattr(args, 'trans_weight_decay', trans_weight_decay)
        self.trans_dropout = getattr(args, 'trans_dropout', trans_dropout)
        self.trans_num_layers = getattr(args, 'trans_num_layers', trans_num_layers)
        self.gnn_weight_decay = gnn_weight_decay

        self.trans_conv = TransConv(in_channels, hidden_channels, self.trans_num_layers, trans_num_heads,
                                    self.trans_dropout,
                                    trans_use_bn, trans_use_residual, trans_use_weight, trans_use_act)
        self.graph_conv = GraphConv(in_channels, hidden_channels, gnn_num_layers, gnn_dropout, gnn_use_bn,
                                    gnn_use_residual, gnn_use_weight, gnn_use_init, gnn_use_act)
        self.use_graph = use_graph
        self.graph_weight = graph_weight
        self.aggregate = aggregate

        if aggregate == 'add':
            fc_in = hidden_channels
        elif aggregate == 'cat':
            # forward() only concatenates when the graph branch is active
            fc_in = 2 * hidden_channels if use_graph else hidden_channels
        else:
            raise ValueError(f'Invalid aggregate type:{aggregate}')
        self.fc = nn.Linear(fc_in, out_channels)

        # Two parameter groups so each branch gets its own weight decay.
        self.params1 = list(self.trans_conv.parameters())
        self.params2 = list(self.graph_conv.parameters()) if self.graph_conv is not None else []
        self.params2.extend(list(self.fc.parameters()))

    def forward(self, x, edge_index):
        """
        Return per-node log-probabilities.

        Args:
            x: Node features.
            edge_index: Passed unchanged to the graph branch as its adjacency
                argument; unused when ``use_graph`` is false.
        """
        x1 = self.trans_conv(x)
        if self.use_graph:
            x2 = self.graph_conv(x, edge_index)
            if self.aggregate == 'add':
                x = self.graph_weight * x2 + (1 - self.graph_weight) * x1
            else:
                x = torch.cat((x1, x2), dim=1)
        else:
            x = x1
        x = self.fc(x)
        return F.log_softmax(x, dim=1)

    def get_attentions(self, x):
        """Return the stacked attention matrices of the attention branch."""
        attns = self.trans_conv.get_attentions(x)  # [layer num, N, N]
        return attns

    def reset_parameters(self):
        """Re-initialize both branches and the classifier head."""
        self.trans_conv.reset_parameters()
        if self.use_graph:
            self.graph_conv.reset_parameters()
        # The head is trained together with the branches, so a fresh fit
        # must not inherit it from a previous run.
        self.fc.reset_parameters()

    def fit_with_val(self, data, train_iters=600, verbose=False,
                     normadj=True, setting='trans', reduced=False, reindex=False,
                     **kwargs):
        """
        Train from scratch and keep the weights with the best validation accuracy.

        Args:
            data: Dataset object. Reads ``adj_syn``/``feat_syn``/``labels_syn``
                when ``reduced`` is set, otherwise ``adj_full``/``feat_full``
                (``setting == 'trans'``) or ``adj_train``/``feat_train``,
                together with ``labels_train``, ``labels_val``, ``idx_train``
                and ``idx_val`` (or ``feat_val``/``adj_val`` for ``'ind'``).
            train_iters: Number of optimization steps.
            verbose: Print progress every 100 steps.
            normadj: Accepted for interface compatibility; not used here.
            setting: ``'trans'`` or ``'ind'``.
            reduced: Train on the synthetic (reduced) graph.
            reindex: Use all training outputs for the loss instead of
                ``output[data.idx_train]``; forced on when ``reduced`` is set
                or ``setting == 'ind'``.
            **kwargs: Ignored.

        Returns:
            The best validation accuracy observed (``0`` if it never rose
            above zero).
        """
        self.reset_parameters()

        # data for training
        if reduced:
            adj, features, labels, labels_val = to_tensor(data.adj_syn, data.feat_syn, data.labels_syn,
                                                          label2=data.labels_val,
                                                          device=self.device)
        elif setting == 'trans':
            adj, features, labels, labels_val = to_tensor(data.adj_full, data.feat_full, label=data.labels_train,
                                                          label2=data.labels_val, device=self.device)
        else:
            adj, features, labels, labels_val = to_tensor(data.adj_train, data.feat_train, label=data.labels_train,
                                                          label2=data.labels_val, device=self.device)
        labels = to_tensor(label=labels, device=self.device)

        self.loss = F.nll_loss

        if reduced or setting == 'ind':
            reindex = True

        if verbose:
            print('=== training ===')

        best_acc_val = 0

        # data for validation
        if setting == 'ind':
            feat_full, adj_full = data.feat_val, data.adj_val
        else:
            feat_full, adj_full = data.feat_full, data.adj_full
        feat_full, adj_full = to_tensor(feat_full, adj_full, device=self.device)

        optimizer = torch.optim.Adam([
            {'params': self.params1, 'weight_decay': self.trans_weight_decay},
            {'params': self.params2, 'weight_decay': self.gnn_weight_decay}
        ], lr=self.lr)

        # Fallback snapshot: without it ``weights`` is unbound when validation
        # accuracy never exceeds 0 or when train_iters == 0.
        weights = deepcopy(self.state_dict())

        for i in range(train_iters):
            # Validation below switches to eval mode, so training mode has to
            # be restored at the start of every step to keep dropout active.
            self.train()
            optimizer.zero_grad()

            output = self.forward(features, adj)
            loss_train = self.loss(output if reindex else output[data.idx_train], labels)
            loss_train.backward()
            optimizer.step()

            if verbose and i % 100 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))
                acc_train = accuracy(output if reindex else output[data.idx_train], labels)
                print('Epoch {}, training acc: {}'.format(i, acc_train))

            with torch.no_grad():
                self.eval()
                output = self.forward(feat_full, adj_full)
                if setting == 'ind':
                    acc_val = accuracy(output, labels_val)
                else:
                    acc_val = accuracy(output[data.idx_val], labels_val)

                if acc_val > best_acc_val:
                    best_acc_val = acc_val
                    weights = deepcopy(self.state_dict())

        if verbose:
            print('=== picking the best model according to the performance on validation ===')
        self.load_state_dict(weights)
        return best_acc_val

    @torch.no_grad()
    def test(self, data, setting='trans', verbose=False):
        """
        Evaluate the model on the test split of the original graph.

        Args:
            data: Dataset object providing ``idx_test``, ``labels_test`` and
                either ``feat_test``/``adj_test`` (``setting == 'ind'``) or
                ``feat_full``/``adj_full``.
            setting: ``'ind'`` scores every output row; any other value
                scores ``output[data.idx_test]``.
            verbose: Print test loss and accuracy.

        Returns:
            Test accuracy as a Python number (``acc_test.item()``).
        """
        self.eval()
        idx_test = data.idx_test
        labels_test = torch.LongTensor(data.labels_test).to(self.device)

        # whether condensed or not, use the raw graph to test
        if setting == 'ind':
            output = self.predict(data.feat_test, data.adj_test)
            loss_test = F.nll_loss(output, labels_test)
            acc_test = accuracy(output, labels_test)
        else:
            output = self.predict(data.feat_full, data.adj_full)
            loss_test = F.nll_loss(output[idx_test], labels_test)
            acc_test = accuracy(output[idx_test], labels_test)

        if verbose:
            print("Test set results:",
                  "loss= {:.4f}".format(loss_test.item()),
                  "accuracy= {:.4f}".format(acc_test.item()))
        return acc_test.item()

    @torch.no_grad()
    def predict(self, features=None, adj=None, normadj=True, output_layer_features=False):
        """
        Return log-probabilities for ``features``/``adj`` in eval mode.

        ``normadj`` and ``output_layer_features`` are accepted for interface
        compatibility and are not used.
        """
        self.eval()
        features, adj = to_tensor(features, adj, device=self.device)
        return self.forward(features, adj)