"""graphslim.sparsification.kcenter_sample -- greedy k-center coreset selection."""

from graphslim.sparsification.model_based_coreset_base import MBCoreSet
import torch
import numpy as np
from collections import Counter
import random


class KCenterSample(MBCoreSet):
    """Coreset selector based on greedy k-center (farthest-point) sampling.

    Nodes are chosen independently for every class in
    ``self.num_class_dict``. The first center is the training node whose
    embedding is closest to the class mean. Each further center is the node
    with the largest Euclidean distance to its nearest already-chosen center.
    """

    def select(self, embeds):
        """Return the ids of the training nodes picked by k-center sampling.

        Parameters
        ----------
        embeds : torch.Tensor
            Node embeddings indexed by node id (rows are compared with
            ``torch.cdist``); presumably ``(num_nodes, dim)`` -- TODO confirm
            against the caller.

        Returns
        -------
        numpy.ndarray
            Selected node ids, concatenated class by class in the iteration
            order of ``self.num_class_dict``. Empty if nothing is selected.

        Notes
        -----
        Reads ``self.num_class_dict`` (class id -> number of nodes to pick),
        ``self.idx_train`` and ``self.labels_train``. The per-class budget is
        capped at the number of training nodes of that class, so a node is
        never returned twice. Classes with a non-positive budget or no
        training nodes contribute nothing.
        """
        idx_selected = []
        with torch.no_grad():  # selection only needs indices, not gradients
            for class_id, cnt in self.num_class_dict.items():
                idx = self.idx_train[self.labels_train == class_id]
                feature = embeds[idx]
                budget = min(int(cnt), feature.shape[0])
                if budget <= 0:
                    continue
                centers = self._k_center_greedy(feature, budget)
                idx_selected.append(idx[centers])

        if not idx_selected:
            # np.hstack raises on an empty sequence.
            return np.empty(0, dtype=np.int64)
        return np.hstack(idx_selected)

    @staticmethod
    def _k_center_greedy(feature, budget):
        """Greedily pick ``budget`` distinct row positions of ``feature``.

        Requires ``1 <= budget <= feature.shape[0]``. The returned list holds
        Python ints in selection order.
        """
        # First center: the row closest to the class mean.
        mean = torch.mean(feature, dim=0, keepdim=True)
        first = torch.argmin(torch.cdist(feature, mean)[:, 0]).item()
        centers = [first]

        # Distance from every row to its nearest chosen center, updated
        # incrementally: one new distance column per step instead of
        # recomputing distances to all centers (O(n*k) instead of O(n*k^2)).
        dis_min = torch.cdist(feature, feature[first:first + 1])[:, 0]
        # Mask chosen rows so rounding in cdist can never re-select them.
        dis_min[first] = float("-inf")
        for _ in range(budget - 1):
            farthest = torch.argmax(dis_min).item()
            centers.append(farthest)
            new_dis = torch.cdist(feature, feature[farthest:farthest + 1])[:, 0]
            dis_min = torch.minimum(dis_min, new_dis)
            dis_min[farthest] = float("-inf")
        return centers
# # def get_sub_adj_feat(self, features): # data = self.data # args = self.args # idx_selected = [] # # counter = Counter(self.data.labels_syn) # labels_train = self.data.labels_train.squeeze().tolist() # important # ids_per_cls_train = [(labels_train == c).nonzero()[0] for c in counter.keys()] # idx_selected = self.sampling(ids_per_cls_train, counter, features, 0.5, counter) # features = features[idx_selected] # # return features, None # # def sampling(self, ids_per_cls_train, budget, vecs, d): # budget_dist_compute = 1000 # ''' # if using_half: # vecs = vecs.half() # ''' # if isinstance(vecs, np.ndarray): # vecs = torch.from_numpy(vecs) # vecs = vecs.half() # ids_selected = [] # for i, ids in enumerate(ids_per_cls_train): # class_ = list(budget.keys())[i] # other_cls_ids = list(range(len(ids_per_cls_train))) # other_cls_ids.pop(i) # ids_selected0 = ids_per_cls_train[i] if len(ids_per_cls_train[i]) < budget_dist_compute else random.choices( # ids_per_cls_train[i], k=budget_dist_compute) # # dist = [] # vecs_0 = vecs[ids_selected0] # for j in other_cls_ids: # chosen_ids = random.choices(ids_per_cls_train[j], k=min(budget_dist_compute, len(ids_per_cls_train[j]))) # vecs_1 = vecs[chosen_ids] # if len(chosen_ids) < 26 or len(ids_selected0) < 26: # # torch.cdist throws error for tensor smaller than 26 # dist.append(torch.cdist(vecs_0.float(), vecs_1.float()).half()) # else: # dist.append(torch.cdist(vecs_0, vecs_1)) # # # dist = [torch.cdist(vecs[ids_selected0], vecs[random.choices(ids_per_cls_train[j], k=min(budget_dist_compute,len(ids_per_cls_train[j])))]) for j in other_cls_ids] # dist_ = torch.cat(dist, dim=-1) # include distance to all the other classes # n_selected = (dist_ < d).sum(dim=-1) # rank = n_selected.sort()[1].tolist() # current_ids_selected = rank[:budget[class_]] if len(rank) > budget[class_] else random.choices(rank, # k=budget[ # class_]) # ids_selected.extend([ids_per_cls_train[i][j] for j in current_ids_selected]) # return ids_selected