Source code for graphslim.sparsification.herding_agg

import numpy as np
import torch

from graphslim.sparsification.model_free_coreset_base import MFCoreSet


[docs] class HerdingAgg(MFCoreSet):
[docs] def select(self, embeds): idx_selected = [] for class_id, cnt in self.num_class_dict.items(): idx = self.idx_train[self.labels_train == class_id] features = embeds[idx] mean = torch.mean(features, dim=0, keepdim=True) selected = [] idx_left = np.arange(features.shape[0]).tolist() for i in range(cnt): det = mean * (i + 1) - torch.sum(features[selected], dim=0) dis = torch.cdist(det, features[idx_left]) id_min = torch.argmin(dis) selected.append(idx_left[id_min]) del idx_left[id_min] idx_selected.append(idx[selected]) return np.hstack(idx_selected)