Source code for graphslim.sparsification.model_based_coreset_base

import numpy as np

from graphslim.dataset.utils import save_reduced
from graphslim.evaluation import *
from graphslim.models import *
from graphslim.sparsification.coreset_base import CoreSet
from graphslim.utils import to_tensor


class MBCoreSet(CoreSet):
    """Model-based coreset: choose nodes using predictions of a trained model.

    ``reduce`` trains ``self.condense_model`` on the original graph, feeds the
    model's predictions to ``self.select`` and keeps the induced subgraph of
    the selected nodes. When ``args.method`` is ``'sfgc'`` a precomputed index
    file is loaded instead of training a model.
    """

    def __init__(self, setting, data, args, **kwarg):
        super().__init__(setting, data, args, **kwarg)

    def _load_fixed_indices(self, args):
        """Load the precomputed node indices used by the ``'sfgc'`` method."""
        # NOTE(review): this path is relative, so it resolves against the
        # current working directory rather than this module's location.
        path = (f'sparsification/fixed_idx/'
                f'idx_{args.dataset}_{args.reduction_rate}_kcenter_15.npy')
        return np.load(path)

    def _select_indices(self, data, model, verbose):
        """Return the indices of the nodes to keep.

        For ``'sfgc'`` the indices come from a fixed file; otherwise ``model``
        is trained on ``data`` and its predictions on the full graph are
        passed to ``self.select``.
        """
        args = self.args
        if args.method in ['sfgc']:
            return self._load_fixed_indices(args)
        model.fit_with_val(data, train_iters=args.eval_epochs, normadj=True,
                           verbose=verbose, setting=args.setting, reduced=False)
        embeds = model.predict(data.feat_full, data.adj_full).detach()
        return self.select(embeds)

    @verbose_time_memory
    def reduce(self, data, verbose=False, save=True):
        """Build a reduced graph from a model-based node selection.

        Args:
            data: graph container. For ``setting == 'trans'`` it must provide
                ``feat_full``/``adj_full``/``labels_full``; for
                ``setting == 'ind'`` it must also provide
                ``feat_train``/``adj_train``/``labels_train``.
            verbose: print the number of selected nodes and induced edges and
                forward verbosity to model training.
            save: persist the reduced graph with ``save_reduced``.

        Returns:
            ``data`` with ``adj_syn``, ``feat_syn`` and ``labels_syn`` set to
            the induced subgraph, converted by ``to_tensor`` onto the CPU.

        Raises:
            ValueError: if ``self.setting`` is neither ``'trans'`` nor ``'ind'``.
        """
        # Fail fast: previously an unknown setting left ``idx_selected``
        # unbound and could save stale ``*_syn`` attributes from ``data``.
        if self.setting not in ('trans', 'ind'):
            raise ValueError(
                f"unsupported setting {self.setting!r}; expected 'trans' or 'ind'")
        args = self.args
        # NOTE(review): ``eval`` on a config string executes arbitrary code;
        # presumably ``condense_model`` is a model class name pulled in by the
        # wildcard imports — it must never come from untrusted input.
        # The model is built even for 'sfgc' (where it is unused) to keep the
        # original device allocation and RNG consumption order.
        model = eval(self.condense_model)(
            data.feat_full.shape[1], args.hidden, data.nclass, args).to(self.device)

        idx_selected = self._select_indices(data, model, verbose)

        if self.setting == 'trans':
            feat, adj, labels = data.feat_full, data.adj_full, data.labels_full
        else:
            # NOTE(review): embeddings are computed on the full graph, but
            # the indices are applied to the train-only arrays; confirm that
            # ``self.select`` returns positions into ``feat_train``.
            feat, adj, labels = data.feat_train, data.adj_train, data.labels_train

        # np.ix_ slices rows and columns together: the induced subgraph.
        data.adj_syn = adj[np.ix_(idx_selected, idx_selected)]
        data.feat_syn = feat[idx_selected]
        data.labels_syn = labels[idx_selected]

        if verbose:
            print('selected nodes:', idx_selected.shape[0])
            print('induced edges:', data.adj_syn.sum())

        data.adj_syn, data.feat_syn, data.labels_syn = to_tensor(
            data.adj_syn, data.feat_syn, data.labels_syn, device='cpu')

        if save:
            save_reduced(data.adj_syn, data.feat_syn, data.labels_syn, args)

        return data