# Source code for graphslim.models.graphsage

import torch
import torch.nn as nn

from graphslim.models.base import BaseGNN
from graphslim.models.layers import SageConvolution


class GraphSage(BaseGNN):
    """GraphSAGE network assembled from ``SageConvolution`` layers.

    Layer layout produced by the constructor:

    * ``nlayers == 1``: a single ``SageConvolution(nfeat, nclass)``.
    * otherwise: ``SageConvolution(nfeat, nhid)``, followed by
      ``nlayers - 2`` copies of ``SageConvolution(nhid, nhid)``, followed by
      ``SageConvolution(nhid, nclass)``.  When ``with_bn`` is truthy, one
      ``BatchNorm1d(nhid)`` is stored in ``self.bns`` for every convolution
      except the last one.

    Parameters
    ----------
    nfeat
        Input size handed to the first convolution.
    nhid
        Size used for hidden convolutions and for every ``BatchNorm1d``.
    nclass
        Output size of the final convolution.
    args
        Forwarded unchanged to ``BaseGNN.__init__``.
    mode
        Forwarded unchanged to ``BaseGNN.__init__`` (default ``'train'``).
    """

    def __init__(self, nfeat, nhid, nclass, args, mode='train'):
        super().__init__(nfeat, nhid, nclass, args, mode)
        # NOTE(review): ``with_bn``, ``nlayers`` and ``layers`` are read from
        # ``self`` right after the base constructor runs, so they are
        # presumably set up by ``BaseGNN.__init__`` -- confirm there.
        use_bn = self.with_bn

        # Guard clause: a one-layer model goes straight from nfeat to nclass
        # and never creates ``self.bns`` here.
        if self.nlayers == 1:
            self.layers.append(SageConvolution(nfeat, nclass))
            return

        # Input layer.  The norm module is instantiated before the
        # convolution, matching the original construction order.
        if use_bn:
            self.bns = nn.ModuleList([nn.BatchNorm1d(nhid)])
        self.layers.append(SageConvolution(nfeat, nhid))

        # Hidden layers: convolution first, then its norm module.
        for _ in range(self.nlayers - 2):
            self.layers.append(SageConvolution(nhid, nhid))
            if use_bn:
                self.bns.append(nn.BatchNorm1d(nhid))

        # Output layer; no norm module is paired with it.
        self.layers.append(SageConvolution(nhid, nclass))
    # def fit_with_val(self, data, train_iters=200, verbose=False,
    #                  normadj=True, setting='trans', reduced=False, reindex=False,
    #                  **kwargs):
    #
    #     self.initialize()
    #     # data for training
    #     if reduced:
    #         adj, features, labels, labels_val = to_tensor(data.adj_syn, data.feat_syn, label=data.labels_syn,
    #                                                       label2=data.labels_val,
    #                                                       device=self.device)
    #     elif setting == 'trans':
    #         adj, features, labels, labels_val = to_tensor(data.adj_full, data.feat_full, data.labels_train,
    #                                                       data.labels_val, device=self.device)
    #     else:
    #         adj, features, labels, labels_val = to_tensor(data.adj_train, data.feat_train, data.labels_train,
    #                                                       data.labels_val, device=self.device)
    #     if normadj:
    #         adj = normalize_adj_tensor(adj, sparse=is_sparse_tensor(adj))
    #
    #     if len(data.labels_full.shape) > 1:
    #         self.multi_label = True
    #         self.loss = torch.nn.BCELoss()
    #     elif len(labels.shape) > 1:  # for GCSNTK, use MSE for training
    #         # print("MSE loss")
    #         self.float_label = True
    #         self.loss = torch.nn.MSELoss()
    #     else:
    #         self.multi_label = False
    #         self.loss = F.nll_loss
    #
    #     if reduced or setting == 'ind':
    #         reindex = True
    #
    #     if verbose:
    #         print('=== training GNN model ===')
    #     optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
    #
    #     best_acc_val = 0
    #     # data for validation
    #     if setting == 'ind':
    #         feat_full, adj_full = data.feat_val, data.adj_val
    #     else:
    #         feat_full, adj_full = data.feat_full, data.adj_full
    #     feat_full, adj_full = to_tensor(feat_full, adj_full, device=self.device)
    #     if normadj:
    #         adj_full = normalize_adj_tensor(adj_full, sparse=is_sparse_tensor(adj_full))
    #
    #     if self.args.method not in ['msgc']:
    #         # adj -> adj (SparseTensor)
    #         # msgc cannot use sampling
    #         adj = dense2sparsetensor(adj)
    #         if adj.density() > 0.5:  # if the weighted graph is too dense, we need a larger neighborhood size
    #             sizes = [30, 20]
    #         else:
    #             sizes = [5, 5]
    #         if reduced:
    #             node_idx = torch.arange(data.labels_syn.size(0), device=self.device)
    #         elif setting == 'ind':
    #             node_idx = torch.arange(data.labels_train.size(0), device=self.device)
    #         else:
    #             node_idx = torch.arange(data.labels_full.size(0), device=self.device)
    #
    #         train_loader = NeighborSampler(adj,
    #                                        node_idx=node_idx,
    #                                        sizes=sizes, batch_size=len(node_idx),
    #                                        num_workers=8, return_e_id=False,
    #                                        num_nodes=adj.size(0),
    #                                        shuffle=True)
    #
    #     best_acc_val = 0
    #     self.train()
    #
    #     for i in range(train_iters):
    #         if i == train_iters // 2:
    #             lr = self.lr * 0.1
    #             optimizer = optim.Adam(self.parameters(), lr=lr, weight_decay=self.weight_decay)
    #
    #         if self.args.method == 'msgc':
    #             optimizer.zero_grad()
    #             output = self.forward(features, adj)
    #             loss_train = self.loss(output if reindex else output[data.idx_train], labels)
    #
    #             loss_train.backward()
    #             optimizer.step()
    #         else:
    #             for batch_size, n_id, adjs in train_loader:
    #                 adjs = [adj[0].to(self.device) for adj in adjs]
    #                 optimizer.zero_grad()
    #                 out = self.forward(features[n_id], adjs)
    #                 loss_train = self.loss(out, labels[n_id[:batch_size]])
    #                 loss_train.backward()
    #                 optimizer.step()
    #
    #         if verbose and i + 1 % 100 == 0:
    #             print('Epoch {}, training loss: {}'.format(i, loss_train.item()))
    #
    #         with torch.no_grad():
    #             self.eval()
    #             output = self.forward(feat_full, adj_full)
    #             if setting == 'ind':
    #                 # loss_val = F.nll_loss(output, labels_val)
    #                 acc_val = accuracy(output, labels_val)
    #             else:
    #                 # loss_val = F.nll_loss(output[data.idx_val], labels_val)
    #                 acc_val = accuracy(output[data.idx_val], labels_val)
    #
    #             if acc_val > best_acc_val:
    #                 best_acc_val = acc_val
    #                 self.output = output
    #                 weights = deepcopy(self.state_dict())
    #
    #     if verbose:
    #         print('=== picking the best model according to the performance on validation ===')
    #     self.load_state_dict(weights)
    #     return best_acc_val