import os
from copy import deepcopy
import scipy
from graphslim.condensation.gcond_base import GCondBase
from graphslim.dataset.utils import save_reduced
from graphslim.evaluation.utils import verbose_time_memory
from graphslim.models import *
from graphslim.models.gntk import GNTK
from graphslim.models.reparam_module import ReparamModule
from graphslim.sparsification import *
from graphslim.utils import *
from tqdm import trange
from torch.optim import Adam, SGD
[docs]
class SFGC(GCondBase):
    """
    "Structure-free Graph Condensation: From Large-scale Graphs to Condensed Graph-free Data." https://arxiv.org/pdf/2306.02664.pdf

    The reduction runs in two stages:

    1. (optional, skipped when ``args.no_buff`` is set) teacher models are
       trained on the original graph and their parameter checkpoints are saved
       as ``replay_buffer_{n}.pt`` files under ``self.buf_dir``;
    2. the synthetic features ``self.feat_syn`` and a learnable student
       learning rate are optimised so that a student trained on the synthetic
       data moves from one saved checkpoint towards a later one.
    """

    def __init__(self, setting, data, args, **kwargs):
        """
        Set up the condenser and make sure the trajectory buffer directory exists.

        Side effects: forces ``args.condense_model`` to ``'GCN'`` and creates
        ``self.buf_dir`` on disk.
        """
        super(SFGC, self).__init__(setting, data, args, **kwargs)
        assert args.teacher_epochs + 100 >= args.expert_epochs
        args.condense_model = 'GCN'
        # args.init = 'kcenter'
        self.buf_dir = '../sfgc_buffer/{}_{}_{}_{}'.format(args.dataset, args.attack, args.ptb_r, args.seed)
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(self.buf_dir, exist_ok=True)

    @staticmethod
    def _snapshot_parameters(model):
        """Return independent CPU copies of the current parameters of ``model``."""
        # ``detach().cpu()`` does NOT copy when the parameter already lives on
        # the CPU, so the stored checkpoint would alias the live parameter and
        # be overwritten by later in-place optimizer steps. ``copy=True``
        # forces exactly one copy regardless of the source device.
        return [p.detach().to('cpu', copy=True) for p in model.parameters()]

    def _save_expert_trajectories(self, data):
        """
        Stage 1: train teacher models and write their checkpoints to ``self.buf_dir``.

        A checkpoint is taken before training and then after every epoch ``e``
        with ``e % 10 == 0 and e > 1``. Trajectories are flushed to a new
        ``replay_buffer_{n}.pt`` file each time ten of them have accumulated.

        Side effects: sets ``args.condense_model = 'GCN'`` and
        ``args.num_experts = 20``; writes files under ``self.buf_dir``.
        """
        args = self.args
        args.condense_model = 'GCN'
        args.num_experts = 20  # 200
        if args.setting == 'ind':
            features, adj, labels = to_tensor(data.feat_train, data.adj_train, label=data.labels_train,
                                              device=self.device)
        else:
            features, adj, labels = to_tensor(data.feat_full, data.adj_full, label=data.labels_full,
                                              device=self.device)
        adj = normalize_adj_tensor(adj, sparse=True)
        device = args.device
        experts_per_file = 10  # saving more per file needs too much space

        trajectories = []
        # SECURITY/NOTE: eval() resolves a class name taken from args; only use
        # trusted configuration values here.
        # NOTE(review): the model is built once and ``model.initialize()`` is
        # commented out, so every expert after the first continues training
        # from the previous expert's weights -- confirm this is intended.
        model = eval(args.condense_model)(features.shape[1], args.hidden, data.nclass, args).to(device)
        for _ in trange(args.num_experts):
            # model.initialize()
            model_parameters = list(model.parameters())
            optimizer_model = eval(args.optim)(model_parameters, lr=args.lr_teacher,
                                               weight_decay=args.wd_teacher)
            timestamps = [self._snapshot_parameters(model)]
            for e in range(args.teacher_epochs):
                model.train()
                optimizer_model.zero_grad()
                output = model.forward(features, adj)
                if args.setting == 'ind':
                    loss_buffer = F.nll_loss(output, labels)
                else:
                    loss_buffer = F.nll_loss(output[data.idx_train], labels[data.idx_train])
                loss_buffer.backward()
                optimizer_model.step()
                if e % 10 == 0 and e > 1:
                    timestamps.append(self._snapshot_parameters(model))
            trajectories.append(timestamps)

            if len(trajectories) == experts_per_file:
                # Pick the first unused buffer index so existing files are kept.
                n = 0
                while os.path.exists(os.path.join(self.buf_dir, "replay_buffer_{}.pt".format(n))):
                    n += 1
                buffer_path = os.path.join(self.buf_dir, "replay_buffer_{}.pt".format(n))
                print("Saving {}".format(buffer_path))
                torch.save(trajectories, buffer_path)
                trajectories = []

    @verbose_time_memory
    def reduce(self, data, verbose=True):
        """
        Condense ``data`` into synthetic features by matching expert trajectories.

        Parameters
        ----------
        data : dataset object
            Provides the original graph and receives ``adj_syn``, ``feat_syn``
            and ``labels_syn`` at every iteration listed in ``args.checkpoints``.
        verbose : bool
            Accepted for interface compatibility; not read by this method.

        Returns
        -------
        The same ``data`` object.
        """
        args = self.args

        # =============stage 1 trajectory save and load==================#
        # can skip to save time
        if not args.no_buff:
            self._save_expert_trajectories(data)

        # =============stage 2 trajectory alignment and GCN evaluation==================#
        # kcenter select
        feat_init, adj_init = self.init(with_adj=True)
        self.feat_syn.data.copy_(feat_init)
        labels_syn = to_tensor(label=data.labels_syn, device=self.device)
        self.adj_syn_init = adj_init

        file_idx, expert_idx, expert_files = self.expert_load(self.buf_dir)

        # The student learning rate is itself learned, with its own optimizer.
        syn_lr = torch.tensor(args.lr_student).float()
        syn_lr = syn_lr.detach().to(self.device).requires_grad_(True)
        optimizer_lr = torch.optim.SGD([syn_lr], lr=1e-6, momentum=0.5)

        def flatten(params):
            """Concatenate a list of parameter tensors into one flat vector on self.device."""
            return torch.cat([p.data.to(self.device).reshape(-1) for p in params], 0)

        best_val = 0
        bar = trange(args.epochs, ncols=100)
        for it in bar:
            # NOTE: eval() resolves a class name taken from args (set to 'GCN'
            # in __init__); only use trusted configuration values here.
            model = eval(args.condense_model)(data.feat_train.shape[1], args.hidden, data.nclass, args).to(self.device)
            model = ReparamModule(model)
            model.train()
            num_params = sum(p.numel() for p in model.parameters())

            # Take the next expert; when the current file is exhausted move on
            # to the next file (reshuffling the file order after a full pass).
            expert_trajectory = self.buffer[expert_idx]
            expert_idx += 1
            if expert_idx == len(self.buffer):
                expert_idx = 0
                file_idx += 1
                if file_idx == len(expert_files):
                    file_idx = 0
                    random.shuffle(expert_files)
                del self.buffer
                # NOTE: torch.load unpickles the file; only load buffers that
                # were produced by a trusted run.
                self.buffer = torch.load(expert_files[file_idx])
                random.shuffle(self.buffer)

            # Sample the starting checkpoint of the segment to imitate.
            start = np.linspace(0, args.start_epoch, num=args.start_epoch // 10 + 1)
            start_epoch = int(np.random.choice(start, 1)[0])
            if args.optim == 'Adam':
                start_epoch = start_epoch // 10
            starting_params = expert_trajectory[start_epoch]
            target_params = expert_trajectory[start_epoch + args.expert_epochs // 10]

            target_params = flatten(target_params)
            student_params = [flatten(starting_params).requires_grad_(True)]
            starting_params = flatten(starting_params)

            feat_syn = self.feat_syn
            if it == 0:
                # First iteration uses the adjacency returned by self.init().
                adj_syn_norm = normalize_adj_tensor(self.adj_syn_init, sparse=True)
                adj_syn_input = to_tensor(adj_syn_norm, device=self.device)
            else:
                # Later iterations use an identity adjacency (no structure).
                adj_syn = torch.eye(feat_syn.shape[0], device=self.device)
                adj_syn_input = normalize_adj_tensor(adj_syn, sparse=False)

            # Unrolled student training; create_graph=True keeps the graph so
            # the final loss can be differentiated w.r.t. feat_syn and syn_lr.
            for step in range(args.syn_steps):
                forward_params = student_params[-1]
                output_syn = model.forward(feat_syn, adj_syn_input, flat_param=forward_params)
                loss_syn = F.nll_loss(output_syn, labels_syn)
                grad = torch.autograd.grad(loss_syn, student_params[-1], create_graph=True)[0]
                student_params.append(student_params[-1] - syn_lr * grad)

            # Distance of the student to the target, normalised by the distance
            # the expert itself travelled over the same segment.
            param_loss = F.mse_loss(student_params[-1], target_params, reduction="sum") / num_params
            param_dist = F.mse_loss(starting_params, target_params, reduction="sum") / num_params
            total_loss = param_loss / param_dist

            # Stop BEFORE the update on a NaN loss: stepping would write NaN
            # into feat_syn and syn_lr and destroy the last finite state.
            if torch.isnan(total_loss):
                break

            self.optimizer_feat.zero_grad()
            optimizer_lr.zero_grad()
            total_loss.backward()
            self.optimizer_feat.step()
            optimizer_lr.step()

            bar.set_postfix_str(
                f"File ID = {file_idx} Total_Loss = {total_loss.item():.4f} Syn_Lr = {syn_lr.item():.4f}")

            if it in args.checkpoints:
                data.adj_syn, data.feat_syn, data.labels_syn = torch.eye(
                    feat_syn.shape[0]), feat_syn.detach(), labels_syn.detach()
                best_val = self.intermediate_evaluation(best_val, total_loss.item())

            # Drop the list holding every unrolled parameter vector.
            del student_params

        return data

    def expert_load(self, expert_dir):
        '''
        Locate the saved expert buffers and load one of them into ``self.buffer``.

        Files ``replay_buffer_0.pt``, ``replay_buffer_1.pt``, ... are collected
        from ``expert_dir`` until the first missing index. The file order is
        shuffled, the first file is loaded and its trajectories are shuffled.

        Returns ``(file_idx, expert_idx, expert_files)`` with both indices 0.
        Raises ``AssertionError`` when no buffer file is found.
        '''
        expert_files = []
        n = 0
        while os.path.exists(os.path.join(expert_dir, "replay_buffer_{}.pt".format(n))):
            expert_files.append(os.path.join(expert_dir, "replay_buffer_{}.pt".format(n)))
            n += 1
        if n == 0:
            raise AssertionError("No buffers detected at {}".format(expert_dir))
        file_idx = 0
        expert_idx = 0
        random.shuffle(expert_files)
        # print("loading file {}".format(expert_files[file_idx]))
        # NOTE: torch.load unpickles the file; only load buffers that were
        # produced by a trusted run.
        buffer = torch.load(expert_files[file_idx])
        random.shuffle(buffer)
        self.buffer = buffer
        return file_idx, expert_idx, expert_files