"""Source code for graphslim.condensation.msgc (MSGC graph condensation)."""

from torch import nn
from tqdm import trange

from collections import Counter
from sklearn.cluster import BisectingKMeans
from graphslim.condensation.gcond_base import GCondBase
from graphslim.condensation.utils import match_loss
from graphslim.dataset.utils import save_reduced
from graphslim.models import *
from torch_scatter import scatter_mean
from graphslim.utils import *


class MSGC(GCondBase):
    """
    "Multiple sparse graphs condensation"
    https://www.sciencedirect.com/science/article/pii/S0950705123006548

    Learns synthetic node features ``feat_syn`` together with an edge scorer
    ``pge``. ``pge`` assigns sigmoid weights to ``batch_size`` fixed sparse
    edge sets over the synthetic nodes; the edge sets are sampled once by
    ``reset_adj_batch``.
    """

    def __init__(self, setting, data, args, **kwargs):
        super(MSGC, self).__init__(setting, data, args, **kwargs)
        x_channels = data.feat_train.shape[1]
        edge_hidden_channels = 256
        # generate_labels_syn allocates tensors on self.device, so the device
        # must be set before that call, not after it.
        self.device = args.device
        self.n_syn = self.nnodes_syn
        n_each_y = self.generate_labels_syn(data)
        self.labels_syn = data.labels_syn = self.y_syn
        self.batch_size = args.batch_adj
        self.n_classes = data.nclass
        self.num_class_dict = self.data.num_class_dict = {
            index: value for index, value in enumerate(n_each_y)
        }
        # Edge scorer: maps the concatenated features of a node pair to a
        # single logit (passed through a sigmoid in get_adj_t_syn).
        self.pge = nn.Sequential(
            nn.Linear(x_channels * 2, edge_hidden_channels),
            nn.BatchNorm1d(edge_hidden_channels),
            nn.ReLU(),
            nn.Linear(edge_hidden_channels, edge_hidden_channels),
            nn.BatchNorm1d(edge_hidden_channels),
            nn.ReLU(),
            nn.Linear(edge_hidden_channels, 1),
        ).to(args.device)

    def generate_labels_syn(self, data):
        """Split the synthetic nodes among classes and build ``self.y_syn``.

        Every class first receives one synthetic node. The remaining
        ``nnodes_syn - nclass`` nodes are split in proportion to the class
        frequencies of ``data.labels_train``, using floor division. Nodes left
        over by flooring are handed out one at a time. Each goes to the class
        with the smallest ratio of current share to training frequency.

        Sets ``self.y_syn`` to class-sorted labels of length ``nnodes_syn``.

        Returns:
            int64 tensor of length ``data.nclass`` with the node count per class.

        Raises:
            ValueError: if ``nnodes_syn < data.nclass``. Every class needs at
                least one synthetic node, so the counts could not sum to
                ``nnodes_syn``.
        """
        labels_train = data.labels_train.to(self.device)
        n = labels_train.shape[0]
        n_syn = self.nnodes_syn
        if n_syn < data.nclass:
            raise ValueError(
                f'MSGC needs at least one synthetic node per class: '
                f'nnodes_syn={n_syn} < nclass={data.nclass}')
        base = torch.ones(data.nclass, device=self.device)
        rate = F.one_hot(labels_train, num_classes=data.nclass).sum(0) / n
        n_each_y = torch.floor((n_syn - base.sum()) * rate) + base
        left = n_syn - n_each_y.sum()
        for _ in range(int(left.item())):
            # Classes with rate == 0 give inf here and are never chosen.
            more = n_each_y / n_each_y.sum() / rate
            n_each_y[more.argmin()] += 1
        n_each_y = n_each_y.to(torch.int64)
        # Class-sorted labels: n_each_y[0] zeros, then n_each_y[1] ones, ...
        # Every entry is written; nothing is left uninitialized.
        y_syn = torch.repeat_interleave(
            torch.arange(data.nclass, device=self.device), n_each_y)
        self.y_syn = y_syn
        if self.args.verbose:
            print(f'num_class:{n_each_y}')
        return n_each_y

    def reduce(self, data, verbose=True):
        """Run MSGC condensation and return ``data`` with the synthetic graph.

        Each epoch re-initializes ``basic_model`` and then runs
        ``args.outer_loop`` steps. Each step does two things:

        1. Backpropagates the loss returned by ``self.train_class``. It then
           updates either the edge scorer ``pge`` (first 10 epochs of every
           50) or the synthetic features ``feat_syn`` (all other epochs).
        2. Trains ``basic_model`` for ``args.inner_loop`` steps on the
           current, detached synthetic graph.

        At every epoch listed in ``args.checkpoints``, the features and
        adjacencies are averaged over the last (up to) ``args.window`` epochs.
        The averages are stored on ``data`` and evaluated via
        ``intermediate_evaluation``.

        Note: this method overwrites ``args.window`` and ``args.patience``
        with 20. ``verbose`` is unused; printing is controlled by
        ``args.verbose``.
        """
        args = self.args
        if args.setting == 'trans':
            features, adj, labels = to_tensor(data.feat_full, data.adj_full, label=data.labels_full,
                                              device=self.device)
        else:
            features, adj, labels = to_tensor(data.feat_train, data.adj_train, label=data.labels_train,
                                              device=self.device)
        adj = normalize_adj_tensor(adj, sparse=True)
        # One copy of the synthetic labels per sampled adjacency in the batch.
        y_syn = to_tensor(label=self.y_syn, device=self.device).repeat(self.batch_size)

        # NOTE(review): eval() on a config string; args.condense_model must
        # name a model class in scope and must never come from untrusted input.
        basic_model = eval(args.condense_model)(self.feat_syn.shape[1], args.hidden, data.nclass, args).to(
            self.device)

        self.reset_adj_batch()
        feat_init = self.init()
        self.feat_syn.data.copy_(feat_init)
        optimizer_x = torch.optim.Adam([self.feat_syn], lr=args.lr_feat)
        optimizer_adj = torch.optim.Adam(self.pge.parameters(), lr=args.lr_adj)
        best_val = 0
        args.window = args.patience = 20
        # Sliding windows over recent epochs; checkpoints use their averages.
        losses = FixLenList(args.window)
        x_syns = FixLenList(args.window)
        adj_t_syns = FixLenList(args.window)
        optimizer_basic_model = torch.optim.Adam(basic_model.parameters(), lr=args.lr)
        for it in trange(args.epochs):
            basic_model.initialize()
            basic_model = self.check_bn(basic_model)
            loss_avg = 0
            for _ in range(args.outer_loop):
                basic_model = self.check_bn(basic_model)
                basic_model.eval()  # fix basic_model while optimizing the synthetic graph
                # ---------------- graph optimization ----------------
                self.adj_syn = self.get_adj_t_syn()
                loss = self.train_class(basic_model, adj, features, labels, y_syn, args)
                loss_avg += loss.item()
                optimizer_x.zero_grad()
                optimizer_adj.zero_grad()
                loss.backward()
                # Alternate targets: edge scorer for the first 10 epochs of
                # every 50, synthetic features for the remaining epochs.
                if it % 50 < 10:
                    optimizer_adj.step()
                else:
                    optimizer_x.step()
                x_syn = self.feat_syn.detach()
                adj_t_syn = self.get_adj_t_syn().detach()
                # ------- model optimization on the synthetic graph -------
                for _ in range(args.inner_loop):
                    optimizer_basic_model.zero_grad()
                    logits = basic_model(x_syn, adj_t_syn)
                    inner_loss = F.nll_loss(logits, y_syn)
                    inner_loss.backward()
                    optimizer_basic_model.step()
            loss_avg /= (data.nclass * args.outer_loop)
            losses.append(loss_avg)
            x_syns.append(x_syn.clone())
            adj_t_syns.append(adj_t_syn.clone())
            loss_window = sum(losses.data) / len(losses.data)
            if args.verbose:
                print(f'loss:{loss_window:.4f} feat:{x_syn.sum().item():.4f} adj:{adj_t_syn.sum().item():.4f}')
            if it in args.checkpoints:
                best_x_syn = sum(x_syns.data) / len(x_syns.data)
                best_adj_t_syn = sum(adj_t_syns.data) / len(adj_t_syns.data)
                data.feat_syn, data.adj_syn, data.labels_syn = best_x_syn, best_adj_t_syn, y_syn
                best_val = self.intermediate_evaluation(best_val, loss_window)
        return data

    # NOTE: an MSGC-specific init() (cluster/sample/mean initialisation
    # restricted to the training set) used to sit here, commented out.
    # reduce() calls the init() inherited from the base class instead.

    def reset_adj_batch(self):
        """Sample ``batch_size`` sparse edge sets over the synthetic nodes.

        For each graph, every synthetic node is linked to one node of every
        class c, with two exceptions. The node itself is never a candidate.
        The node is also skipped for class c if it already has two or more
        links into that class.

        The chosen partner is the class-c node with the fewest links into
        the source node's own class; ties are broken uniformly at random.
        Each link is recorded in both directions.

        Stores the concatenated edge lists of all graphs in ``self.rows`` and
        ``self.cols``. ``self.batch`` holds the graph index of each entry.
        """
        rows = []
        cols = []
        batch = []
        # Class membership masks depend only on y_syn, so build them once.
        # Each use clones its mask before mutating it.
        class_masks = [self.y_syn == c for c in range(self.n_classes)]
        for i in range(self.batch_size):
            # n_neighbor[v, c]: number of links sampled so far from node v
            # into nodes of class c.
            n_neighbor = torch.zeros(self.n_syn, self.n_classes, device=self.device)
            index = torch.arange(self.n_syn, device=self.device)
            row = []
            col = []
            for row_id in range(self.n_syn):
                for c in range(self.n_classes):
                    c_mask = class_masks[c].clone()
                    c_mask[row_id] = False  # no self-loops
                    if c_mask.sum() == 0 or n_neighbor[row_id][c] > 1:
                        continue
                    # Prefer class-c nodes with the fewest links into row_id's class.
                    link_coef = n_neighbor[c_mask, self.y_syn[row_id]]
                    selected = link_coef.argmin()
                    candidates_mask = link_coef == link_coef[selected]
                    if candidates_mask.sum() == 1:
                        col_id = index[c_mask][selected].item()
                    else:
                        # Break ties uniformly at random.
                        candidates = index[c_mask][candidates_mask]
                        col_id = candidates[torch.randperm(candidates.shape[0], device=self.device)[0]].item()
                    n_neighbor[row_id, c] += 1
                    n_neighbor[col_id, self.y_syn[row_id]] += 1
                    # Record the undirected edge in both directions.
                    row.append(row_id)
                    row.append(col_id)
                    col.append(col_id)
                    col.append(row_id)
            rows.append(torch.LongTensor(row))
            cols.append(torch.LongTensor(col))
            batch.append(torch.LongTensor([i] * len(row)))
        self.rows = torch.cat(rows).to(self.device)
        self.cols = torch.cat(cols).to(self.device)
        self.batch = torch.cat(batch).to(self.device)
        # Average number of undirected edges per graph (each edge is stored twice).
        n_edge = self.rows.shape[0] / self.batch_size / 2
        if self.args.verbose:
            print(f'n_edge:{n_edge}')

    def get_adj_t_syn(self):
        """Build the batch of normalized synthetic adjacencies.

        The result is a dense tensor of shape ``(batch_size, n_syn, n_syn)``.
        Each sampled edge is weighted by ``sigmoid(pge([x_row, x_col]))``. The
        matrix is symmetrized by averaging with its transpose and given
        self-loops. Finally each graph is normalized as
        D^{-1/2} (A + I) D^{-1/2}.
        """
        adj = torch.zeros(size=(self.batch_size, self.n_syn, self.n_syn), device=self.device)
        adj[self.batch, self.rows, self.cols] = torch.sigmoid(self.pge(
            torch.cat([self.feat_syn[self.rows], self.feat_syn[self.cols]], dim=1)).flatten())
        adj = (torch.transpose(adj, 1, 2) + adj) / 2
        adj += torch.eye(self.n_syn, device=self.device).view(1, self.n_syn, self.n_syn)
        deg = adj.sum(2)
        deg_inv = deg.pow(-1 / 2)
        deg_inv[torch.isinf(deg_inv)] = 0.
        adj = adj * deg_inv.view(self.batch_size, -1, 1)
        adj = adj * deg_inv.view(self.batch_size, 1, -1)
        return adj
class FixLenList:
    """Bounded list holding only the ``lenth`` most recently appended items.

    Items are kept oldest-first in the public ``data`` list. Whenever an
    append pushes the length past ``lenth``, the single oldest item is
    evicted.
    """

    def __init__(self, lenth):
        # The attribute keeps its original (misspelled) name, ``lenth``.
        self.lenth = lenth
        self.data = []

    def append(self, element):
        """Add ``element`` at the end, dropping the oldest item if over capacity."""
        items = self.data
        items.append(element)
        if self.lenth < len(items):
            items.pop(0)