# Source code for graphslim.sparsification.model_free_coreset_base

import numpy as np
import torch
from torch_sparse import matmul

from graphslim.dataset.utils import save_reduced
from graphslim.evaluation.utils import verbose_time_memory
from graphslim.sparsification.coreset_base import CoreSet
from graphslim.utils import normalize_adj_tensor, to_tensor


class MFCoreSet(CoreSet):
    """Coreset reducer that builds the reduced graph from a selected node subset.

    ``reduce`` obtains node indices from ``self.select`` and writes the reduced
    graph onto ``data``. ``select`` is not defined in this class; it is
    presumably provided by ``CoreSet`` or by a concrete subclass (confirm).
    """

    def __init__(self, setting, data, args, **kwargs):
        super().__init__(setting, data, args, **kwargs)

    @verbose_time_memory
    def reduce(self, data, verbose=False, save=True):
        """Reduce ``data`` to the nodes chosen by ``self.select``.

        With ``self.setting == 'trans'`` the full graph (``adj_full``,
        ``feat_full``, ``labels_full``) is reduced. With ``'ind'`` the training
        graph (``adj_train``, ``feat_train``, ``labels_train``) is reduced.

        If ``self.args.agg`` is set, features are first propagated through the
        squared output of ``normalize_adj_tensor``. Nodes are then selected on
        those aggregated features, and the reduced adjacency is the identity.
        Otherwise the reduced adjacency is the subgraph induced by the selected
        nodes.

        Parameters
        ----------
        data
            Graph container; it is read from and mutated in place.
        verbose : bool
            Print the number of selected nodes and the sum of the reduced
            adjacency.
        save : bool
            Persist the result via ``save_reduced``.

        Returns
        -------
        data
            The same object, with ``adj_syn``, ``feat_syn`` and ``labels_syn``
            converted by ``to_tensor(..., device='cpu')``.

        Raises
        ------
        ValueError
            If ``self.setting`` is neither ``'trans'`` nor ``'ind'``.
        """
        args = self.args
        if self.setting == 'trans':
            adj, feat, labels = data.adj_full, data.feat_full, data.labels_full
        elif self.setting == 'ind':
            # Use the training-subgraph features so they line up with adj_train
            # and labels_train (the aggregated path previously used feat_full).
            adj, feat, labels = data.adj_train, data.feat_train, data.labels_train
        else:
            raise ValueError(
                f"Unsupported setting {self.setting!r}; expected 'trans' or 'ind'."
            )

        if args.agg:
            idx_selected = self._reduce_aggregated(data, adj, feat, labels)
        else:
            idx_selected = self._reduce_induced(data, adj, feat, labels)

        if verbose:
            print('selected nodes:', idx_selected.shape[0])
            print('induced edges:', data.adj_syn.sum())
        data.adj_syn, data.feat_syn, data.labels_syn = to_tensor(
            data.adj_syn, data.feat_syn, label=data.labels_syn, device='cpu'
        )
        if save:
            save_reduced(data.adj_syn, data.feat_syn, data.labels_syn, args)
        return data

    def _reduce_aggregated(self, data, adj, feat, labels):
        """Select nodes on 2-step propagated features; use an identity adjacency."""
        data.adj_fully = to_tensor(adj)
        data.pre_conv = normalize_adj_tensor(data.adj_fully, sparse=True)
        # Squaring the operator gives two propagation steps in one product.
        data.pre_conv = matmul(data.pre_conv, data.pre_conv)
        feat_agg = matmul(data.pre_conv, feat).float()
        idx_selected = self.select(feat_agg)
        data.feat_syn = feat_agg[idx_selected]
        # Neighbourhood information is already folded into feat_agg, so the
        # reduced graph keeps only self-connections.
        data.adj_syn = torch.eye(data.feat_syn.shape[0], device=self.args.device)
        data.labels_syn = labels[idx_selected]
        return idx_selected

    def _reduce_induced(self, data, adj, feat, labels):
        """Select nodes and keep the subgraph they induce in ``adj``."""
        idx_selected = self.select()
        # np.ix_ builds the row/column cross-product, i.e. the induced submatrix.
        data.adj_syn = adj[np.ix_(idx_selected, idx_selected)]
        data.feat_syn = feat[idx_selected]
        data.labels_syn = labels[idx_selected]
        return idx_selected