"""Source code for graphslim.dataset.attack."""

import os
import sys
import torch

import numpy as np

if os.path.abspath('..') not in sys.path:
    sys.path.append(os.path.abspath('..'))
from graphslim.config import *
from graphslim.dataset import *
import logging
from graphslim.models import *
import scipy.sparse as sp


def attack(data, args):
    """Corrupt the graph stored in ``data`` according to ``args.attack``.

    Supported values of ``args.attack``:

    * ``'metattack'`` / ``'random_adj'`` -- perturb the adjacency matrix
      (``data.adj_train`` when ``args.setting == 'ind'``, otherwise
      ``data.adj_full``). The result is cached with ``scipy.sparse.save_npz``.
    * ``'random_feat'`` -- perturb the node features (``data.feat_train``
      when ``args.setting == 'ind'``, otherwise ``data.feat_full``). The
      result is cached with ``torch.save``.

    Cached files live in ``{args.save_path}/corrupt_graph/{args.attack}`` and
    are keyed by dataset, attack, perturbation ratio and seed. If a cached
    file exists it is loaded instead of recomputed. A freshly computed
    corruption is evaluated by fitting a GCN, and the test accuracy is logged
    through ``args.logger``. Any other value of ``args.attack`` leaves
    ``data`` unchanged.

    Parameters
    ----------
    data:
        Dataset object; the selected adjacency / feature attribute is
        replaced in place (and, for transductive metattack, ``edge_index``).
    args:
        Configuration namespace. Reads ``seed``, ``save_path``, ``attack``,
        ``hidden``, ``device``, ``dataset``, ``ptb_r``, ``setting``,
        ``eval_epochs``, ``verbose`` and ``logger``; sets ``args.ptb_n`` when
        a new corruption is computed.

    Returns
    -------
    The same ``data`` object.
    """
    seed_everything(args.seed)
    save_path = f'{args.save_path}/corrupt_graph/{args.attack}'
    # Built before any branching so the RNG consumption order after
    # seed_everything is the same regardless of cache hits.
    gcn_model = GCN(nfeat=data.x.shape[1], nhid=args.hidden, nclass=data.nclass,
                    args=args, mode='attack').to(args.device)
    os.makedirs(save_path, exist_ok=True)

    inductive = args.setting == 'ind'
    # Shared file-name stem so every cache file of one run uses the same key.
    tag = f'{args.dataset}_{args.attack}_{args.ptb_r}_{args.seed}'

    def _evaluate():
        # Fit a GCN on the freshly corrupted graph and log its test accuracy.
        gcn_model.fit_with_val(data, train_iters=args.eval_epochs,
                               verbose=args.verbose, setting=args.setting)
        test_acc = gcn_model.test(data, setting=args.setting, verbose=True)
        args.logger.info(f'attack {args.attack}_{args.ptb_r} test acc: {test_acc}')

    if args.attack in ['metattack', 'random_adj']:
        adj_attr = 'adj_train' if inductive else 'adj_full'
        adj_path = f'{save_path}/adj_{tag}.npz'
        if os.path.exists(adj_path):
            setattr(data, adj_attr, sp.load_npz(adj_path))
            print(f'load corrupt graph at {adj_path}')
            return data

        adj = getattr(data, adj_attr)
        # adj.sum() // 2 counts each edge once, assuming a symmetric adjacency
        # (NOTE(review): confirm adj is symmetric 0/1).
        args.ptb_n = int(args.ptb_r * (adj.sum() // 2))

        if args.attack == 'metattack':
            model = PRBCD(data, device=args.device)
            if inductive:
                # ignore the test results!
                edge_index, _ = model.attack(ptb_rate=args.ptb_r)
                # Keep only the training-node sub-graph of the attacked graph.
                new_adj = ei2csr(edge_index.cpu(), data.num_nodes)[np.ix_(data.idx_train, data.idx_train)]
            else:
                data.edge_index, _ = model.attack(data.edge_index, ptb_rate=args.ptb_r)
                new_adj = ei2csr(data.edge_index.cpu(), data.num_nodes)
        else:  # 'random_adj'
            model = RandomAttack()
            model.attack(adj, n_perturbations=args.ptb_n, type='add')
            new_adj = model.modified_adj.tocsr()

        setattr(data, adj_attr, new_adj)
        sp.save_npz(adj_path, new_adj)
        _evaluate()
        print(f'save corrupt graph at {adj_path}')

    elif args.attack == 'random_feat':
        feat_attr = 'feat_train' if inductive else 'feat_full'
        # One path for the existence check, the load and the save.
        feat_path = f'{save_path}/feat_{tag}.pt'
        if os.path.exists(feat_path):
            # NOTE(review): if modified_features is not a plain tensor, newer
            # torch versions (weights_only=True default) may refuse to load it.
            setattr(data, feat_attr, torch.load(feat_path))
            print(f'load corrupt graph at {feat_path}')
            return data

        model = RandomAttack(attack_structure=False, attack_features=True)
        args.ptb_n = int(args.ptb_r * data.x.shape[1])
        model.attack(getattr(data, feat_attr), n_perturbations=args.ptb_n)
        setattr(data, feat_attr, model.modified_features)
        _evaluate()
        # torch.save writes exactly feat_path (sp.save_npz would append
        # '.npz'), so the os.path.exists check above finds it next time.
        torch.save(getattr(data, feat_attr), feat_path)
        print(f'save corrupt graph at {feat_path}')

    return data