Source code for graphslim.models.appnp

"""multiple transformation and multiple propagation"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_sparse
from torch_sparse import SparseTensor

from graphslim.models.base import BaseGNN
from graphslim.models.layers import MyLinear


class APPNP(BaseGNN):
    """APPNP-style model: a stack of linear transformations, then propagation.

    The input features first pass through ``self.ntrans`` ``MyLinear``
    layers (optionally with batch norm, activation and dropout between
    them). The resulting predictions ``h`` are then smoothed over the graph
    by repeating ``x = (1 - alpha) * adj @ x + alpha * h``.

    NOTE(review): ``self.layers``, ``self.ntrans``, ``self.with_bn``,
    ``self.nlayers``, ``self.alpha``, ``self.dropout`` and
    ``self.multi_label`` are read here but never assigned in this class;
    they are presumably set by ``BaseGNN.__init__`` -- confirm there.
    """

    def __init__(self, nfeat, nhid, nclass, args, mode='train'):
        """Build the transformation layers and select the activation.

        Args:
            nfeat: Input feature dimension.
            nhid: Hidden dimension used between transformation layers.
            nclass: Output dimension (number of classes).
            args: Configuration object; ``args.activation`` is read here and
                the whole object is forwarded to ``BaseGNN.__init__``.
            mode: Forwarded unchanged to ``BaseGNN.__init__``.
        """
        super(APPNP, self).__init__(nfeat, nhid, nclass, args, mode)

        if self.ntrans == 1:
            # Single transformation: map features straight to class scores.
            self.layers.append(MyLinear(nfeat, nclass))
        else:
            self.layers.append(MyLinear(nfeat, nhid))
            if self.with_bn:
                # One batch-norm per non-final transformation layer.
                self.bns = torch.nn.ModuleList()
                self.bns.append(nn.BatchNorm1d(nhid))
            for _ in range(self.ntrans - 2):
                if self.with_bn:
                    self.bns.append(nn.BatchNorm1d(nhid))
                self.layers.append(MyLinear(nhid, nhid))
            self.layers.append(MyLinear(nhid, nclass))

        # Adjacency dropout module; created with dprob=0 and currently not
        # applied in forward().
        self.sparse_dropout = SparseDropout(dprob=0)

        activation_functions = {
            'sigmoid': torch.sigmoid,
            'tanh': torch.tanh,
            'relu': F.relu,
            'linear': lambda x: x,
            'softplus': F.softplus,
            'leakyrelu': F.leaky_relu,
            'relu6': F.relu6,
            'elu': F.elu
        }
        # NOTE: an unrecognised name yields None here; the failure would
        # only surface when the activation is first called.
        self.activation = activation_functions.get(args.activation)

    def forward(self, x, adj, output_layer_features=False):
        """Full-graph forward pass.

        Args:
            x: Node feature tensor.
            adj: Adjacency operator; either a ``torch_sparse.SparseTensor``
                or any tensor supporting ``adj @ x``.
            output_layer_features: Unused; kept for interface compatibility.

        Returns:
            Log-softmax (over dim 1) of the propagated predictions, after
            flattening all leading dimensions into one.
        """
        last = len(self.layers) - 1
        for ix, layer in enumerate(self.layers):
            x = layer(x)
            if ix != last:
                x = self.bns[ix](x) if self.with_bn else x
                x = self.activation(x)
                x = F.dropout(x, self.dropout, training=self.training)

        # h keeps the un-propagated predictions for the teleport term.
        h = x
        # Here nlayers means K, the number of propagation steps.
        for _ in range(self.nlayers):
            if isinstance(adj, SparseTensor):
                x = torch_sparse.matmul(adj, x)
            else:
                x = adj @ x
            x = x * (1 - self.alpha)
            x = x + self.alpha * h

        # Collapse any leading dimensions so log_softmax runs over classes.
        x = x.view(-1, x.shape[-1])
        return F.log_softmax(x, dim=1)

    def forward_sampler(self, x, adjs):
        """Forward pass over a sequence of sampled adjacency blocks.

        Args:
            x: Node feature tensor for the sampled nodes.
            adjs: Iterable of ``(adj, _, size)`` triples; ``size[1]`` is used
                to truncate the teleport term before each propagation step
                (presumably the number of target nodes -- confirm with the
                sampler that produces ``adjs``).

        Returns:
            ``sigmoid`` outputs when ``self.multi_label`` is truthy,
            otherwise log-softmax over dim 1.
        """
        last = len(self.layers) - 1
        for ix, layer in enumerate(self.layers):
            x = layer(x)
            if ix != last:
                x = self.bns[ix](x) if self.with_bn else x
                # Use the configured activation so this path computes the
                # same function as forward() (previously hard-coded ReLU).
                x = self.activation(x)
                x = F.dropout(x, self.dropout, training=self.training)

        h = x
        for adj, _, size in adjs:
            h = h[: size[1]]
            x = torch_sparse.matmul(adj, x)
            x = x * (1 - self.alpha)
            x = x + self.alpha * h

        if self.multi_label:
            return torch.sigmoid(x)
        return F.log_softmax(x, dim=1)
class SparseDropout(torch.nn.Module):
    """Dropout applied to the stored entries of a sparse COO tensor.

    Each stored value is kept with probability ``1 - dprob`` and kept values
    are rescaled by ``1 / (1 - dprob)``.
    """

    def __init__(self, dprob=0.5):
        """
        Args:
            dprob: Probability of dropping each stored entry.
        """
        super(SparseDropout, self).__init__()
        # Keep-probability; forward() rescales surviving values by 1 / kprob.
        self.kprob = 1 - dprob

    def forward(self, x, training):
        """Randomly drop stored entries of ``x`` when ``training`` is truthy.

        Args:
            x: Sparse COO tensor (accessed through ``_indices()`` and
                ``_values()``).
            training: When falsy, ``x`` is returned unchanged.

        Returns:
            A new sparse COO tensor with the same size as ``x`` holding the
            surviving, rescaled entries; or ``x`` itself when not training.
        """
        if not training:
            return x

        indices = x._indices()
        values = x._values()

        if self.kprob <= 0:
            # Nothing can survive; return an empty tensor rather than
            # dividing by zero when rescaling.
            return torch.sparse_coo_tensor(indices[:, :0], values[:0], x.size())

        # floor(U + kprob) is 1 with probability kprob for U ~ Uniform[0, 1).
        # The noise is drawn on the values' device so the mask can index
        # tensors that live on an accelerator.
        noise = torch.rand(values.size(), device=values.device)
        mask = (noise + self.kprob).floor().type(torch.bool)

        kept_indices = indices[:, mask]
        kept_values = values[mask] * (1.0 / self.kprob)
        # sparse_coo_tensor replaces the deprecated torch.sparse.FloatTensor
        # constructor and preserves the dtype/device of the inputs.
        return torch.sparse_coo_tensor(kept_indices, kept_values, x.size())
# class APPNPRich(BaseGNN):
#     '''
#     two transformation layer
#     '''
#
#     def __init__(self, nfeat, nhid, nclass, nlayers=2, dropout=0.5, lr=0.01, weight_decay=5e-4, alpha=0.1,
#                  activation="relu", with_relu=True, with_bias=True, with_bn=False, device=None):
#
#         super(APPNPRich, self).__init__(nfeat, nhid, nclass, nlayers=2, dropout=0.5, lr=0.01, weight_decay=5e-4,
#                                         with_relu=True, with_bias=True, with_bn=False, device=device)
#
#         self.alpha = alpha
#         activation_functions = {
#             'sigmoid': F.sigmoid,
#             'tanh': F.tanh,
#             'relu': F.relu,
#             'linear': lambda x: x,
#             'softplus': F.softplus,
#             'leakyrelu': F.leaky_relu,
#             'relu6': F.relu6,
#             'elu': F.elu
#         }
#         self.activation = activation_functions.get(activation)
#
#         if with_bn:
#             self.bns = torch.nn.ModuleList()
#             self.bns.append(nn.BatchNorm1d(nhid))
#
#         self.layers = nn.ModuleList([])
#         self.layers.append(MyLinear(nfeat, nhid))
#         self.layers.append(MyLinear(nhid, nclass))
#
#         # if nlayers == 1:
#         #     self.layers.append(nn.Linear(nfeat, nclass))
#         # else:
#         #     self.layers.append(nn.Linear(nfeat, nhid))
#         #     for i in range(nlayers-2):
#         #         self.layers.append(nn.Linear(nhid, nhid))
#         #     self.layers.append(nn.Linear(nhid, nclass))
#
#         self.nlayers = nlayers
#         self.dropout = dropout
#         self.lr = lr
#
#         self.sparse_dropout = SparseDropout(dprob=0)
#
#     def forward(self, x, adj, output_layer_features=None):
#         for ix, layer in enumerate(self.layers):
#             x = layer(x)
#             if ix != len(self.layers) - 1:
#                 x = self.bns[ix](x) if self.with_bn else x
#                 # x = F.relu(x)
#                 x = self.activation(x)
#                 x = F.dropout(x, self.dropout, training=self.training)
#
#         h = x
#         # here nlayers means K
#         for i in range(self.nlayers):
#             # adj_drop = self.sparse_dropout(adj, training=self.training)
#             adj_drop = adj
#             if isinstance(adj_drop, SparseTensor):
#                 x = torch_sparse.matmul(adj_drop, x)
#             else:
#                 x = torch.spmm(adj_drop, x)
#             x = x * (1 - self.alpha)
#             x = x + self.alpha * h
#
#         if self.multi_label:
#             return torch.sigmoid(x)
#         else:
#             return F.log_softmax(x, dim=1)
#
#     def forward_sampler(self, x, adjs):
#         for ix, layer in enumerate(self.layers):
#             x = layer(x)
#             if ix != len(self.layers) - 1:
#                 x = self.bns[ix](x) if self.with_bn else x
#                 x = self.activation(x)
#                 x = F.dropout(x, self.dropout, training=self.training)
#
#         h = x
#         for ix, (adj, _, size) in enumerate(adjs):
#             adj_drop = adj
#             h = h[: size[1]]
#             x = torch_sparse.matmul(adj_drop, x)
#             x = x * (1 - self.alpha)
#             x = x + self.alpha * h
#
#         if self.multi_label:
#             return torch.sigmoid(x)
#         else:
#             return F.log_softmax(x, dim=1)