"""Source code for ``graphslim.models.base``."""

from copy import deepcopy

import torch.nn as nn
import torch.optim as optim

from graphslim.utils import *


class BaseGNN(nn.Module):
    """Shared scaffolding for graphslim GNN models.

    Stores hyper-parameters read from ``args`` and implements the common
    forward pass, a train-with-validation loop, and test/predict helpers.
    ``self.layers`` is created empty here and ``self.bns`` (used when
    ``with_bn`` is set) is never created here. Both are presumably populated
    by subclasses.
    """

    def __init__(self, nfeat, nhid, nclass, args, mode):
        """Read hyper-parameters from ``args`` and apply ``mode`` overrides.

        Parameters
        ----------
        nfeat, nhid, nclass
            Not used by ``BaseGNN`` itself; presumably consumed by subclasses.
        args
            Namespace providing ``with_bn``, ``weight_decay``, ``lr``,
            ``dropout``, ``alpha``, ``nlayers``, ``ntrans``, ``device``,
            ``multi_label`` and ``metric``. ``fit_with_val`` reads more fields.
        mode
            ``'eval'`` forces dropout to 0 and weight decay to 5e-4.
            ``'attack'`` fixes the loss to ``F.nll_loss``.
        """
        super(BaseGNN, self).__init__()
        self.args = args
        self.with_bn = args.with_bn
        self.with_relu = True
        self.with_bias = True
        self.weight_decay = args.weight_decay
        self.lr = args.lr
        self.dropout = args.dropout
        self.alpha = args.alpha
        self.nlayers = args.nlayers
        self.ntrans = args.ntrans
        self.device = args.device
        self.layers = nn.ModuleList([])
        # None means "let fit_with_val pick a loss from the data/args".
        self.loss = None
        if mode == 'eval':
            self.dropout = 0
            self.weight_decay = 5e-4
        if mode == 'attack':
            self.loss = F.nll_loss
        # Placeholders not read inside this class (presumably used by callers/subclasses).
        self.output = None
        self.best_model = None
        self.best_output = None
        self.adj_norm = None
        self.features = None
        self.multi_label = args.multi_label
        self.float_label = None
        # Called as metric(output, labels); fit_with_val/test call .item() on its result.
        self.metric = args.metric

    def initialize(self):
        """Re-initialize every layer (and batch-norm layer when ``with_bn``)."""
        for layer in self.layers:
            layer.reset_parameters()
        if self.with_bn:
            for bn in self.bns:
                bn.reset_parameters()

    def forward(self, x, adj, output_layer_features=False):
        """Run the layer stack.

        Parameters
        ----------
        x
            Node features fed to the first layer.
        adj
            A single adjacency shared by all layers, or a ``list`` whose
            ``i``-th entry is passed to layer ``i``.
        output_layer_features
            If true, also return the per-layer outputs, each flattened to
            ``(-1, last_dim)``.

        Returns
        -------
        ``sigmoid`` scores when ``multi_label`` is set, otherwise
        ``log_softmax`` over dim 1. When ``output_layer_features`` is true,
        the result is ``(feat_list, scores)``.
        """
        per_layer_adj = isinstance(adj, list)
        # Defined for both adjacency forms; previously only the
        # single-adjacency path created it, so the list path crashed.
        feat_list = []
        for ix, layer in enumerate(self.layers):
            x = layer(x, adj[ix] if per_layer_adj else adj)
            # Hidden layers get BN / ReLU / dropout; the last layer does not.
            if ix != self.nlayers - 1:
                if self.with_bn:
                    x = self.bns[ix](x)
                if self.with_relu:
                    x = F.relu(x)
                x = F.dropout(x, self.dropout, training=self.training)
            if output_layer_features and ix < self.nlayers:
                feat_list.append(x.reshape(-1, x.shape[-1]))

        # reshape (not view) so non-contiguous layer outputs are accepted.
        x = x.reshape(-1, x.shape[-1])
        scores = torch.sigmoid(x) if self.multi_label else F.log_softmax(x, dim=1)
        if output_layer_features:
            return feat_list, scores
        return scores

    def fit_with_val(self, data, train_iters=600, verbose=False, normadj=True, setting='trans',
                     reduced=False, final_output=False, best_val=None, **kwargs):
        """Train from scratch and keep the weights with the best validation metric.

        Parameters
        ----------
        data
            Dataset object. Reads ``adj_syn``/``feat_syn``/``labels_syn`` when
            ``reduced``, else ``adj_full``/``feat_full`` (``setting == 'trans'``)
            or ``adj_train``/``feat_train``, plus ``labels_train``,
            ``labels_val``, ``nclass``, ``idx_train``, ``idx_val`` and, for
            ``setting == 'ind'``, ``feat_val``/``adj_val``.
        train_iters
            Number of epochs. The learning rate drops 10x at the halfway point
            when ``lr > 0.001``.
        verbose
            Print progress every 100 epochs.
        normadj
            Normalize the *validation* adjacency. The training adjacency is
            always handled by the model-specific rules below.
        setting
            ``'trans'`` validates on the full graph; ``'ind'`` validates on
            ``data.feat_val``/``data.adj_val``.
        reduced
            Train on the synthetic (condensed) graph.
        final_output
            Return ``None`` right after training and keep the last-epoch
            weights instead of restoring the best ones.
        best_val
            Starting threshold the validation metric must beat. Defaults to 0.
        **kwargs
            Ignored.

        Returns
        -------
        float or None
            Best validation metric seen, or ``None`` if ``final_output``.

        Raises
        ------
        NotImplementedError
            Labels are 2-D and ``args.eval_loss`` is neither ``'MSE'`` nor ``'KLD'``.
        """
        args = self.args
        self.initialize()

        # ---- training data ------------------------------------------------
        if reduced:
            adj, features, labels, labels_val = to_tensor(data.adj_syn, data.feat_syn, data.labels_syn,
                                                          label2=data.labels_val, device=self.device)
        elif setting == 'trans':
            adj, features, labels, labels_val = to_tensor(data.adj_full, data.feat_full,
                                                          label=data.labels_train,
                                                          label2=data.labels_val, device=self.device)
        else:
            adj, features, labels, labels_val = to_tensor(data.adj_train, data.feat_train,
                                                          label=data.labels_train,
                                                          label2=data.labels_val, device=self.device)

        # ---- adjacency normalization (model/method specific) --------------
        if self.__class__.__name__ == 'GAT':
            # GAT must use a sparse tensor.
            if len(adj.shape) == 3:
                adj = [normalize_adj_tensor(a.to_sparse(), sparse=True) for a in adj]
            else:
                if not is_sparse_tensor(adj):
                    adj = adj.to_sparse()
                adj = normalize_adj_tensor(adj, sparse=True)
        # A SparseTensor synthetic graph is only used in graphsage, msgc and simgc.
        elif self.__class__.__name__ == 'GraphSage' and args.method == 'msgc':
            pass  # keep the adjacency exactly as provided
        elif args.method == 'simgc':
            adj = normalize_adj_tensor(adj, sparse=True)
        else:
            adj = normalize_adj_tensor(adj, sparse=is_sparse_tensor(adj))

        # ---- loss selection -----------------------------------------------
        if self.loss is None:
            if args.method == 'geom' and args.soft_label:
                self.loss = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)
            elif data.nclass == 1:
                self.loss = torch.nn.BCELoss()
            elif len(labels.shape) == 2:
                # 2-D (soft) labels: regression-style or distribution-matching loss.
                if args.eval_loss == 'MSE':
                    self.loss = torch.nn.MSELoss()
                elif args.eval_loss == 'KLD':
                    self.loss = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)
                else:
                    raise NotImplementedError(f"Unsupported eval_loss: {args.eval_loss!r}")
                self.weight_decay = args.eval_wd
            else:
                labels = to_tensor(label=labels, device=self.device)
                self.loss = F.nll_loss
        else:
            labels = to_tensor(label=labels, device=self.device)

        if verbose:
            print('=== training ===')
        best_acc_val = 0 if best_val is None else best_val

        # ---- validation data ----------------------------------------------
        if setting == 'ind':
            feat_full, adj_full = data.feat_val, data.adj_val
        else:
            feat_full, adj_full = data.feat_full, data.adj_full
        feat_full, adj_full = to_tensor(feat_full, adj_full, device=self.device)
        if normadj:
            adj_full = normalize_adj_tensor(adj_full, sparse=is_sparse_tensor(adj_full))

        # ---- training loop --------------------------------------------------
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        weights = None  # best state_dict so far; stays None if validation never improves
        for i in range(train_iters):
            if i == train_iters // 2 and self.lr > 0.001:
                # Re-creating Adam also resets its moment estimates (original behaviour).
                optimizer = optim.Adam(self.parameters(), lr=self.lr * 0.1, weight_decay=self.weight_decay)

            # The validation step below switches to eval mode. Switch back every
            # epoch so dropout/BatchNorm behave as in training.
            self.train()
            optimizer.zero_grad()
            output = self.forward(features, adj)
            # The forward pass may cover the whole graph; then keep only the training nodes.
            train_out = output if output.shape[0] == labels.shape[0] else output[data.idx_train]
            loss_train = self.loss(train_out, labels)
            loss_train.backward()
            optimizer.step()

            if verbose and i % 100 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))
                acc_train = accuracy(train_out, labels)
                print('Epoch {}, training acc: {}'.format(i, acc_train))

            with torch.no_grad():
                self.eval()
                output = self.forward(feat_full, adj_full)
                val_out = output if output.shape[0] == labels_val.shape[0] else output[data.idx_val]
                acc_val = self.metric(val_out, labels_val)
                if acc_val > best_acc_val:
                    best_acc_val = acc_val
                    weights = deepcopy(self.state_dict())

        if final_output:
            return
        if verbose:
            print('=== picking the best model according to the performance on validation ===')
        if weights is not None:
            self.load_state_dict(weights)
        # The metric may return a tensor/NumPy scalar, and the fallback may be a plain number.
        return best_acc_val.item() if hasattr(best_acc_val, 'item') else float(best_acc_val)

    @torch.no_grad()
    def test(self, data, setting='trans', verbose=False):
        """Evaluate the model on the test split with ``self.metric``.

        Parameters
        ----------
        data
            Dataset object. Reads ``labels_test``. With ``setting == 'ind'`` it
            reads ``feat_test``/``adj_test``; otherwise it reads
            ``feat_full``/``adj_full`` and ``idx_test``.
        setting
            ``'ind'`` predicts on the test graph only; any other value predicts
            on the full graph and selects ``idx_test`` rows.
        verbose
            Also compute and print the NLL test loss.

        Returns
        -------
        The test metric as a Python number.
        """
        self.eval()
        labels_test = torch.LongTensor(data.labels_test).to(self.device)
        # Whether condensed or not, evaluation always uses the raw graph.
        if setting == 'ind':
            output = self.predict(data.feat_test, data.adj_test)
        else:
            output = self.predict(data.feat_full, data.adj_full)[data.idx_test]
        acc_test = self.metric(output, labels_test)
        if verbose:
            loss_test = F.nll_loss(output, labels_test)
            print("Test set results:",
                  "loss= {:.4f}".format(loss_test.item()),
                  "accuracy= {:.4f}".format(acc_test.item()))
        return acc_test.item()

    @torch.no_grad()
    def predict(self, features=None, adj=None, normadj=True, output_layer_features=False):
        """Run ``forward`` in eval mode without gradients.

        ``features``/``adj`` are moved to ``self.device`` via ``to_tensor``.
        ``adj`` is normalized first when ``normadj`` is true.
        """
        self.eval()
        features, adj = to_tensor(features, adj, device=self.device)
        if normadj:
            adj = normalize_adj_tensor(adj, sparse=is_sparse_tensor(adj))
        return self.forward(features, adj, output_layer_features=output_layer_features)