Source code for graphslim.sparsification.coreset_base

from collections import Counter

import numpy as np


class CoreSet:
    """Base class for coreset-style graph reduction methods.

    On construction it resolves three things:

    - how many training nodes to keep per class (``num_class_dict``);
    - the training labels;
    - the training indices.

    It reuses values already attached to ``data`` when ``data.labels_syn``
    is set. Otherwise it computes them with :meth:`prepare_select`.
    """

    def __init__(self, setting, data, args, **kwargs):
        """Store the inputs and resolve per-class selection budgets.

        Args:
            setting: Experiment setting, stored as ``self.setting``.
                NOTE(review): the branching below reads ``args.setting``
                rather than this parameter. Presumably the two agree;
                confirm.
            data: Dataset object. It must expose ``labels_train`` and
                ``idx_train``. If it also has a non-``None``
                ``labels_syn``, its ``num_class_dict`` is reused instead of
                being recomputed.
            args: Config namespace. Reads ``device`` and ``setting``, and
                ``reduction_rate`` via :meth:`prepare_select`.
            **kwargs: Accepted for interface compatibility; unused here.
        """
        self.data = data
        self.args = args
        self.setting = setting
        self.device = args.device
        if getattr(data, 'labels_syn', None) is not None:
            # Budgets were already attached to ``data``; reuse them.
            self.num_class_dict = data.num_class_dict
            self.labels_train = data.labels_train
            self.idx_train = self._train_indices(data, args)
        else:
            (self.num_class_dict,
             self.labels_train,
             self.idx_train) = self.prepare_select(data, args)
        self.condense_model = 'GCN'

    @staticmethod
    def _train_indices(data, args):
        """Return the training-node indices for ``args.setting``.

        In the ``'ind'`` setting the indices are re-based to
        ``0 .. len(data.idx_train) - 1``. Presumably this is because the
        inductive training graph is stored as its own subgraph; confirm.
        In any other setting, ``data.idx_train`` is returned unchanged.
        """
        if args.setting == 'ind':
            return np.arange(len(data.idx_train))
        return data.idx_train

    def prepare_select(self, data, args):
        """Compute how many nodes to keep for each training class.

        A class ``c`` with ``num`` training nodes gets a budget of
        ``max(int(num * args.reduction_rate), 1)``. The product is truncated
        toward zero, and every class present in the training labels keeps
        at least one node.

        Args:
            data: Dataset object exposing ``labels_train`` and
                ``idx_train``. ``labels_train`` must support ``.tolist()``;
                assumed to be a NumPy array or torch tensor.
            args: Config namespace. Reads ``setting`` and
                ``reduction_rate``.

        Returns:
            tuple: ``(num_class_dict, labels_train, idx_train)``, where:

            - ``num_class_dict`` maps class label to budget. Keys are
              inserted in ascending order of class frequency; ties keep
              first-seen order because ``sorted`` is stable.
            - ``labels_train`` is ``data.labels_train`` itself (not a copy).
            - ``idx_train`` comes from :meth:`_train_indices`.
        """
        idx_train = self._train_indices(data, args)
        labels_train = data.labels_train
        class_counts = Counter(labels_train.tolist())
        # Iterate rarest class first. The resulting dict key order is part
        # of the observable output, so keep it stable.
        num_class_dict = {
            label: max(int(count * args.reduction_rate), 1)
            for label, count in sorted(class_counts.items(),
                                       key=lambda item: item[1])
        }
        return num_class_dict, labels_train, idx_train