Source code for graphslim.condensation.gcsntk

import time

import scipy.sparse as sp
import torch
from torch import nn
from torch.nn import functional as F

from graphslim.condensation.utils import normalize_data, GCF
from graphslim.condensation.utils import sub_E, update_E
from graphslim.condensation.gcond_base import GCondBase
from graphslim.dataset import LargeDataLoader
from graphslim.dataset.utils import save_reduced
from graphslim.evaluation.utils import verbose_time_memory
from graphslim.models import StructureBasedNeuralTangentKernel, KernelRidgeRegression
from tqdm import trange
from graphslim.utils import seed_everything


class GCSNTK(GCondBase):
    """Graph condensation via kernel ridge regression with a structure-based NTK.

    Implements "Fast Graph Condensation with Structure-based Neural Tangent
    Kernel" (https://arxiv.org/pdf/2310.11046).

    Synthetic node features ``x_s`` and soft labels ``y_s`` are learned with
    Adam so that a kernel ridge regressor fitted on the synthetic graph
    predicts the one-hot labels of the real training nodes.
    """

    # Datasets that are condensed batch by batch through ``LargeDataLoader``
    # instead of from one dense adjacency matrix.
    _BATCHED_DATASETS = ('flickr', 'reddit', 'ogbn-arxiv')

    def __init__(self, setting, data, args, **kwargs):
        """Store the kernel hyper-parameters read from ``args``.

        Reads ``args.k``, ``args.K``, ``args.ridge``, ``args.L`` and
        ``args.scale``; everything else is handled by ``GCondBase``.
        """
        super(GCSNTK, self).__init__(setting, data, args, **kwargs)
        self.k = args.k
        self.K = args.K
        self.ridge = args.ridge
        self.L = args.L
        self.scale = args.scale

    def train(self, KRR, G_t, G_s, y_t, y_s, E_t, E_s, loss_fn, optimizer,
              accumulate_steps=None, i=None, TRAIN_K=None):
        """Run one optimisation step (or one gradient-accumulation micro-step).

        Args:
            KRR: object whose ``forward(G_t, G_s, y_t, y_s, E_t, E_s)`` returns
                ``(pred, acc)``.
            G_t, y_t, E_t: target (real) features, labels and structure.
            G_s, y_s, E_s: synthetic features, labels and structure.
            loss_fn: loss applied to the float32 prediction and target.
            optimizer: optimizer holding the synthetic parameters.
            accumulate_steps: if ``None``, step after every call. Otherwise
                gradients are accumulated and the optimizer only steps every
                ``accumulate_steps`` calls and on the last batch.
            i: zero-based batch index; required when ``accumulate_steps`` is set.
            TRAIN_K: number of batches per epoch; needed to flush the gradients
                of a trailing, incomplete accumulation window.

        Returns:
            ``(G_s, y_s, loss, acc * 100)`` where ``loss`` is a Python float.
            In accumulation mode the reported loss is the scaled one
            (``loss / accumulate_steps``).
        """
        pred, acc = KRR.forward(G_t, G_s, y_t, y_s, E_t, E_s)

        loss = loss_fn(pred.to(torch.float32), y_t.to(torch.float32))
        loss = loss.to(torch.float32)

        if accumulate_steps is None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        else:
            # Scale so that the accumulated gradient is an average over the window.
            loss = loss / accumulate_steps
            loss.backward()

            # Step at the end of every full window, and also on the final
            # batch so a partial window is not silently dropped.
            if (i + 1) % accumulate_steps == 0 or i == TRAIN_K - 1:
                optimizer.step()
                optimizer.zero_grad()

        return G_s, y_s, loss.item(), acc * 100

    def test(self, KRR, G_t, G_s, y_t, y_s, E_t, E_s, loss_fn):
        """Evaluate the kernel regressor without tracking gradients.

        Prints accuracy and loss and returns ``(loss, accuracy_in_percent)``.
        ``y_t`` is indexed with ``argmax(1)``, i.e. it is used as a
        (batch, classes) score / one-hot matrix.
        """
        size = len(y_t)
        with torch.no_grad():
            pred, _ = KRR.forward(G_t, G_s, y_t, y_s, E_t, E_s)
            # Same float32 casting as in ``train`` so integer one-hot targets
            # are accepted by the loss.
            test_loss = loss_fn(pred.to(torch.float32), y_t.to(torch.float32)).item()
            hits = (pred.argmax(1) == y_t.argmax(1)).type(torch.float).sum().item()
        correct = hits / size
        print(f"Val Acc: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f}", end='\n')
        return test_loss, correct * 100

    def _init_synthetic_adj(self, x_s, cond_size, k):
        """Build the synthetic structure ``A_s`` and place it on ``self.device``.

        With ``args.adj`` set the structure comes from ``update_E`` applied to
        the current synthetic features with parameter ``k``; otherwise a sparse
        identity matrix (self-loops only) is used.
        """
        if self.args.adj:
            adj_syn = update_E(x_s.data, k)
        else:
            diag = torch.arange(cond_size)
            adj_syn = torch.sparse_coo_tensor(torch.stack([diag, diag], dim=0),
                                              torch.ones(cond_size),
                                              torch.Size([cond_size, cond_size]))
        # KRR and the target structure live on self.device, so A_s must too.
        return adj_syn.to(self.device)

    def _checkpoint(self, data, A_s, x_s, y_s, best_val, training_loss):
        """Write the current synthetic graph to ``data`` and evaluate it."""
        data.adj_syn = A_s.detach().to_dense()
        data.feat_syn = x_s.detach()
        data.labels_syn = y_s.detach()
        return self.intermediate_evaluation(best_val, training_loss)

    @verbose_time_memory
    def reduce(self, data, verbose=True):
        """Learn the condensed graph and store it on ``data``.

        At every epoch listed in ``args.checkpoints`` the attributes
        ``data.adj_syn``, ``data.feat_syn`` and ``data.labels_syn`` are
        overwritten with the current synthetic graph and
        ``intermediate_evaluation`` is called.

        Args:
            data: graph data object; for non-batched datasets it must provide
                ``x``, ``y``, ``edge_index``, ``idx_train``, ``labels_train``,
                ``train_mask`` and ``feat_train``.
            verbose: unused here; kept for interface compatibility.

        Returns:
            The same ``data`` object.
        """
        args = self.args
        batched = args.dataset in self._BATCHED_DATASETS

        if batched:
            train_loader = LargeDataLoader(name=args.dataset, split='train',
                                           batch_size=args.batch_size,
                                           split_method='kmeans')
            TRAIN_K, n_train, n_class, dim, _ = train_loader.properties()
            train_loader.split_batch()
        else:
            n_class = len(torch.unique(data.y))
            n, dim = data.x.shape
            edge_index = data.edge_index
            # scipy needs host-side arrays; convert explicitly so a CUDA
            # edge_index does not break the construction.
            dense = sp.coo_matrix(
                (torch.ones(edge_index.shape[1]).numpy(), edge_index.detach().cpu().numpy()),
                shape=(n, n)).toarray()
            adj = torch.from_numpy(dense)
            adj = adj + torch.eye(adj.shape[0])  # add self-loops
            idx_train = data.idx_train
            y_train_one_hot = F.one_hot(data.y, n_class)[data.train_mask]
            n_train = len(data.labels_train)

        cond_size = round(n_train * args.reduction_rate)

        # Learnable synthetic features and soft labels, initialised uniformly
        # in [0, 1) (torch.rand).
        x_s = torch.rand(cond_size, dim, device=args.device)
        y_s = torch.rand(cond_size, n_class, device=args.device)
        x_s.requires_grad = True
        y_s.requires_grad = True
        optimizer = torch.optim.Adam([x_s, y_s], lr=args.lr)

        SNTK = StructureBasedNeuralTangentKernel(K=self.K, L=self.L, scale=self.scale).to(self.device)
        ridge = torch.tensor(self.ridge).to(self.device)
        KRR = KernelRidgeRegression(SNTK.nodes_gram, ridge).to(self.device)
        MSEloss = nn.MSELoss().to(self.device)

        best_val = 0
        if batched:
            A_s = self._init_synthetic_adj(x_s, cond_size, 3)
            for it in trange(args.epochs):
                for i in range(TRAIN_K):
                    x_train, label, sub_A_t = train_loader.get_batch(i)
                    y_train_one_hot = F.one_hot(label.reshape(-1), n_class)
                    _, _, training_loss, _ = self.train(
                        KRR, x_train.to(self.device), x_s,
                        y_train_one_hot.to(self.device), y_s,
                        sub_A_t.to(self.device), A_s, MSEloss, optimizer,
                        args.accumulate_steps, i, TRAIN_K)
                # Checkpoint once per epoch, after all batches were processed.
                if it in args.checkpoints:
                    best_val = self._checkpoint(data, A_s, x_s, y_s, best_val, training_loss)
        else:
            E_train = sub_E(idx_train, adj).to(self.device)
            y_train_one_hot = y_train_one_hot.to(self.device)
            x_train = data.feat_train.to(self.device)
            A_s = self._init_synthetic_adj(x_s, cond_size, 4)
            for it in trange(args.epochs):
                x_s, y_s, training_loss, _ = self.train(
                    KRR, x_train, x_s, y_train_one_hot, y_s,
                    E_train, A_s, MSEloss, optimizer)
                if it in args.checkpoints:
                    best_val = self._checkpoint(data, A_s, x_s, y_s, best_val, training_loss)

        return data