import json
import os.path as osp
import os
import pickle
import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
from ogb.nodeproppred import PygNodePropPredDataset
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
from torch_geometric.datasets import Planetoid, Coauthor, CitationFull, Amazon, Flickr, Reddit2
from torch_geometric.loader import NeighborSampler
from torch_geometric.utils import to_undirected, add_self_loops
from torch_sparse import SparseTensor
from dgl.data import FraudDataset
import shutil
from graphslim.dataset.convertor import ei2csr, csr2ei, from_dgl
from graphslim.dataset.utils import splits
from graphslim.utils import index_to_mask, to_tensor
try:
from torch_geometric.data.data import DataEdgeAttr, DataTensorAttr # type: ignore
except ImportError:
DataEdgeAttr = None # type: ignore
DataTensorAttr = None # type: ignore
try:
from torch_geometric.data.storage import GlobalStorage # type: ignore
except ImportError:
GlobalStorage = None # type: ignore
try:
import torch.serialization as torch_serialization
except ImportError:
torch_serialization = None
def get_dataset(name='cora', args=None, load_path='./data'):
    """Load a node-classification benchmark and wrap it in a ``TransAndInd``.

    Parameters
    ----------
    name : str
        Dataset name. Matching against the supported names ignores case,
        ``-`` and ``_`` (e.g. ``'Ogbn_Arxiv'`` resolves to ``'ogbn-arxiv'``);
        names outside the standard list (e.g. ``'cora_full'``) are used as given.
    args : object
        Run configuration; must provide ``split`` (forwarded to ``splits``) and
        ``pre_norm`` (forwarded to ``TransAndInd`` as ``norm``).
    load_path : str
        Root directory where datasets are stored / downloaded.

    Returns
    -------
    TransAndInd
        The wrapped graph, with ``nclass`` set.

    Raises
    ------
    ValueError
        If ``args`` is missing or ``name`` is not recognized.
    """
    # Fail fast instead of downloading/processing a dataset and then crashing on args.split.
    if args is None:
        raise ValueError("get_dataset requires `args` providing `split` and `pre_norm`.")
    path = load_path

    def _normalize(dataset_name):
        return dataset_name.lower().replace('-', '').replace('_', '')

    standard_names = ['flickr', 'reddit', 'dblp', 'cora_ml', 'physics', 'cs', 'cora', 'citeseer', 'pubmed', 'photo',
                      'computers', 'ogbn-products', 'ogbn-proteins', 'ogbn-papers100m', 'ogbn-arxiv', 'yelp', 'amazon']
    name_dict = {_normalize(standard): standard for standard in standard_names}
    name = name_dict.get(_normalize(name), name)  # transfer to the standard name when known

    if name == 'flickr':
        dataset = Flickr(root=path + '/flickr')
    elif name == 'reddit':
        dataset = Reddit2(root=path + '/reddit')
    elif name in ['dblp', 'cora_ml', 'cora_full', 'citeseer_full']:
        dataset = CitationFull(root=path, name=name)
    elif name in ['physics', 'cs']:
        dataset = Coauthor(root=path, name=name)
    elif name in ['cora', 'citeseer', 'pubmed']:
        dataset = Planetoid(root=path, name=name)
    elif name in ['photo', 'computers']:
        dataset = Amazon(root=path, name=name)
    elif name == 'ogbn-arxiv':
        dataset = DataGraphSAINT(root=path, dataset=name)
        dataset.num_classes = 40
    elif name in ['ogbn-products', 'ogbn-proteins', 'ogbn-papers100m']:
        # OGB dataset names are case-sensitive: the registered name is 'ogbn-papers100M'.
        ogb_name = 'ogbn-papers100M' if name == 'ogbn-papers100m' else name
        dataset = PygNodePropPredDataset(ogb_name, root=path)
    elif name in ['yelp', 'amazon']:
        dataset = FraudDataset(name, raw_dir=path)
        dataset = from_dgl(dataset[0], name=name, hetero=False)  # dgl2pyg
    else:
        raise ValueError("Dataset name not recognized.")

    try:
        data = dataset[0]
    except (KeyError, IndexError, TypeError, AttributeError):
        # Loaders returning a single graph object (presumably the DGL conversion
        # above) are not indexable as a dataset; use the object itself.
        data = dataset

    # pyg2TransAndInd: add splits
    data = splits(data, args.split)
    data = TransAndInd(data, name, args.pre_norm)
    try:
        data.nclass = dataset.num_classes
    except AttributeError:
        # TransAndInd has no ``num_classes``; derive the class count from the labels.
        data.nclass = int(data.labels_full.max()) + 1

    print("train nodes num:", data.train_mask.sum().item())
    print("val nodes num:", data.val_mask.sum().item())
    print("test nodes num:", data.test_mask.sum().item())
    print("total nodes num:", data.x.shape[0])
    return data
class TransAndInd:
    """Graph container exposing a transductive (full-graph) and an inductive (per-split) view.

    Full-graph attributes use both PyG naming (``x``, ``y``, ``edge_index``) and
    GraphSAINT naming (``feat_full``, ``labels_full``, ``adj_full``), plus the
    ``sparse_adj`` SparseTensor. Per-split attributes (``idx_*``, ``*_mask``,
    ``adj_*``, ``labels_*``, ``feat_*``) are provided for train/val/test.

    ``nclass`` is not set here; the caller must assign it (``get_dataset`` does)
    before ``retrieve_class`` / ``retrieve_class_sampler`` are used.
    """

    def __init__(self, data, dataset, norm=True):
        """Build both views from ``data``.

        Parameters
        ----------
        data : object
            Graph exposing ``num_nodes``, ``train_mask``/``val_mask``/``test_mask``,
            ``idx_train``/``idx_val``/``idx_test`` and either PyG attributes
            (``x``, ``y``, ``edge_index``) or GraphSAINT attributes
            (``feat_full``, ``labels_full``, ``adj_full``).
        dataset : str
            Standard dataset name; selects the feature preprocessing.
        norm : bool
            L1 row-normalize the features of cora/citeseer/pubmed.
        """
        self.class_dict = None  # per-class masks over training nodes, built lazily by retrieve_class
        self.samplers = None  # per-class NeighborSamplers, built lazily by retrieve_class_sampler
        self.class_dict2 = None  # per-class node indices used by retrieve_class_sampler
        self.sparse_adj = None
        self.adj_full = None
        self.feat_full = None
        self.labels_full = None
        self.num_nodes = data.num_nodes
        self.train_mask, self.val_mask, self.test_mask = data.train_mask, data.val_mask, data.test_mask
        self.pyg_saint(data)

        if dataset in ['flickr', 'reddit', 'ogbn-arxiv']:
            self.edge_index = to_undirected(self.edge_index, self.num_nodes)
            # Standardize all features using statistics fitted on the training nodes only.
            scaler = StandardScaler()
            scaler.fit(self.x[data.idx_train])
            self.feat_full = torch.from_numpy(scaler.transform(self.x)).float()
        if norm and dataset in ['cora', 'citeseer', 'pubmed']:
            self.feat_full = F.normalize(self.feat_full, p=1, dim=1)

        self.idx_train, self.idx_val, self.idx_test = data.idx_train, data.idx_val, data.idx_test
        # Inductive view: induced sub-adjacencies and slices for each split.
        self.adj_train = self.adj_full[np.ix_(self.idx_train, self.idx_train)]
        self.adj_val = self.adj_full[np.ix_(self.idx_val, self.idx_val)]
        self.adj_test = self.adj_full[np.ix_(self.idx_test, self.idx_test)]
        self.labels_train = self.labels_full[self.idx_train]
        self.labels_val = self.labels_full[self.idx_val]
        self.labels_test = self.labels_full[self.idx_test]
        self.feat_train = self.feat_full[self.idx_train]
        self.feat_val = self.feat_full[self.idx_val]
        self.feat_test = self.feat_full[self.idx_test]

    def to(self, device):
        """Move the feature/label/edge tensors to ``device`` and return ``self``.

        The per-split label tensors and the adjacency matrices are not moved.
        """
        self.feat_full = self.feat_full.to(device)
        self.labels_full = self.labels_full.to(device)
        self.x = self.x.to(device)
        self.y = self.y.to(device)
        self.edge_index = self.edge_index.to(device)
        self.feat_train = self.feat_train.to(device)
        self.feat_val = self.feat_val.to(device)
        self.feat_test = self.feat_test.to(device)
        return self

    def pyg_saint(self, data):
        """Populate PyG-style and GraphSAINT-style full-graph attributes from ``data``.

        Attributes reference ``data``'s objects rather than copies. Returns ``data``.
        """
        if hasattr(data, 'x'):
            # PyG format: x, y, edge_index
            num_nodes = data.x.shape[0]
            self.x = data.x
            self.y = data.y
            self.feat_full = data.x
            self.labels_full = data.y
            self.adj_full = ei2csr(data.edge_index, num_nodes)
            self.edge_index = data.edge_index
            # Pass the node count explicitly; otherwise the size is inferred from the
            # largest index in edge_index and trailing isolated nodes are dropped.
            self.sparse_adj = SparseTensor.from_edge_index(data.edge_index, sparse_sizes=(num_nodes, num_nodes))
        elif hasattr(data, 'feat_full'):
            # GraphSAINT format: feat_full, labels_full, adj_full
            num_nodes = data.adj_full.shape[0]
            self.adj_full = data.adj_full
            self.feat_full = data.feat_full
            self.labels_full = data.labels_full
            self.edge_index = csr2ei(data.adj_full)
            self.sparse_adj = SparseTensor.from_edge_index(self.edge_index, sparse_sizes=(num_nodes, num_nodes))
            self.x = data.feat_full
            self.y = data.labels_full
        return data

    def retrieve_class(self, c, num=256):
        """Return up to ``num`` shuffled positions (within the training split) of class ``c``."""
        if self.class_dict is None:
            self.class_dict = {'class_%s' % i: (self.labels_train == i) for i in range(self.nclass)}
        positions = np.arange(len(self.labels_train))[self.class_dict['class_%s' % c]]
        return np.random.permutation(positions)[:num]

    @staticmethod
    def _neighbor_sizes(args):
        """Per-layer neighbor fan-outs for ``args.nlayers`` message-passing layers."""
        if args.nlayers == 2 and args.dataset in ['reddit', 'flickr']:
            return [15, 8]
        fanouts = {1: [15], 2: [10, 5], 3: [15, 10, 5], 4: [15, 10, 5, 5], 5: [15, 10, 5, 5, 5]}
        if args.nlayers not in fanouts:
            raise ValueError(f'Unsupported number of layers for neighbor sampling: {args.nlayers}')
        return list(fanouts[args.nlayers])

    def retrieve_class_sampler(self, c, adj, args, num=256):
        """Sample a neighborhood mini-batch of up to ``num`` nodes of class ``c``.

        Per-class node lists and ``NeighborSampler`` objects are built on the first
        call and cached until :meth:`reset`.

        Parameters
        ----------
        c : int
            Class index.
        adj : object
            Adjacency handed to ``NeighborSampler`` (presumably a SparseTensor;
            only ``adj.size(0)`` is read here directly).
        args : object
            Needs ``setting`` (``'trans'`` selects node ids from ``idx_train``,
            otherwise positions inside the training split), ``nlayers`` (1-5)
            and ``dataset``.
        num : int
            Batch size.

        Returns
        -------
        The result of ``NeighborSampler.sample`` for the sampled batch.

        Raises
        ------
        ValueError
            If the samplers must be built and ``args.nlayers`` is not in 1..5.
        """
        if self.class_dict2 is None:
            self.class_dict2 = {}
            for i in range(self.nclass):
                if args.setting == 'trans':
                    idx = self.idx_train[self.labels_train == i]
                else:
                    idx = np.arange(len(self.labels_train))[self.labels_train == i]
                self.class_dict2[i] = idx
        if self.samplers is None:
            sizes = self._neighbor_sizes(args)
            self.samplers = [
                NeighborSampler(adj,
                                node_idx=torch.LongTensor(self.class_dict2[i]),
                                sizes=sizes, batch_size=num,
                                num_workers=8, return_e_id=False,
                                num_nodes=adj.size(0),
                                shuffle=True)
                for i in range(self.nclass)
            ]
        batch = np.random.permutation(self.class_dict2[c])[:num]
        return self.samplers[c].sample(batch.astype(np.int64))

    def reset(self):
        """Drop cached samplers / class indices and clear the synthetic graph."""
        self.samplers = None
        self.class_dict2 = None
        self.labels_syn, self.feat_syn, self.adj_syn = None, None, None
class LargeDataLoader(nn.Module):
    """Partition the nodes of one split of a large graph into feature-space clusters.

    Supported datasets: ``'ogbn-arxiv'`` (GraphSAINT format, training nodes,
    features GCF-smoothed) and ``'flickr'`` / ``'reddit'`` (PyG, nodes of
    ``split``). Names are matched case-insensitively.
    """

    def __init__(self, name='Flickr', split='train', batch_size=200, split_method='kmeans'):
        super(LargeDataLoader, self).__init__()
        path = './data'
        # Case-insensitive so the default 'Flickr' resolves to a supported dataset.
        name = name.lower()
        if name == 'ogbn-arxiv':
            self._init_arxiv(path, batch_size)
        elif name in ('flickr', 'reddit'):
            self._init_pyg(name, path, split, batch_size)
        else:
            raise ValueError(f'Unsupported dataset for LargeDataLoader: {name}')
        self.split_method = split_method

    def _init_arxiv(self, path, batch_size):
        """Load ogbn-arxiv and keep standardized, 1-hop smoothed training features."""
        dataset = DataGraphSAINT(root=path, dataset='ogbn-arxiv')
        dataset.num_classes = 40
        data = dataset[0]
        self.n, self.dim = data.feat_full.shape
        labels = data.labels_full
        features = to_tensor(data.feat_full)
        self.Adj = self._self_looped_adj(csr2ei(data.adj_full))
        features = self.normalize_data(features)
        features = self.GCF(self.Adj, features, k=1)
        self.split_idx = torch.tensor(data.idx_train)
        self.n_split = len(self.split_idx)
        self.k = self._num_batches(batch_size)
        self.split_feat = features[self.split_idx]
        self.split_label = labels[self.split_idx]
        self.n_classes = dataset.num_classes

    def _init_pyg(self, name, path, split, batch_size):
        """Load Flickr / Reddit2 from PyG and keep standardized features of ``split``."""
        if name == 'flickr':
            from torch_geometric.datasets import Flickr as DataSet
        else:
            from torch_geometric.datasets import Reddit2 as DataSet
        pyg_dataset = DataSet(root=path + f'/{name}')
        data = pyg_dataset[0]  # index once and reuse the graph
        self.n, self.dim = data.x.shape
        labels = data.y
        self.Adj = self._self_looped_adj(data.edge_index)
        features = self.normalize_data(data.x)
        self.split_idx = torch.where(data[split + '_mask'])[0]
        self.n_split = len(self.split_idx)
        self.k = self._num_batches(batch_size)
        # Adjacency (with self-loops) restricted to the split's nodes.
        self.Adj_mask = self._induced_subgraph(self.split_idx)
        self.split_feat = features[self.split_idx]
        self.split_label = labels[self.split_idx]
        self.n_classes = pyg_dataset.num_classes

    def _self_looped_adj(self, edge_index):
        """Sparse ``n x n`` adjacency of ``edge_index`` (unit weights) plus the identity."""
        values = torch.ones(edge_index.shape[1])
        adj = torch.sparse_coo_tensor(edge_index, values, torch.Size([self.n, self.n]))
        sparse_eye = torch.sparse_coo_tensor(torch.arange(self.n).repeat(2, 1), torch.ones(self.n),
                                             (self.n, self.n))
        return adj + sparse_eye

    def _num_batches(self, batch_size):
        """Number of clusters: ``round(n_split / batch_size)``, at least 1 (KMeans rejects 0)."""
        k = torch.round(torch.tensor(self.n_split / batch_size)).to(torch.int)
        return torch.clamp(k, min=1)

    def _induced_subgraph(self, idx_raw):
        """Return ``Adj`` restricted to rows/columns ``idx_raw`` as ``P^T Adj P``.

        ``P`` is the sparse ``n x len(idx_raw)`` selection matrix with a one at
        ``(idx_raw[j], j)``.
        """
        m = len(idx_raw)
        selector_index = torch.cat((idx_raw.reshape(1, m), torch.arange(m).reshape(1, m)), dim=0)
        selector = torch.sparse_coo_tensor(selector_index, torch.ones(m), torch.Size([self.n, m]))
        return torch.sparse.mm(torch.sparse.mm(selector.t(), self.Adj), selector)

    def normalize_data(self, data):
        """
        Standardize each column (feature) to zero mean and unit standard deviation.
        parameters:
            data: torch.Tensor, 2-D data to be normalized (rows are samples)
        return:
            torch.Tensor, normalized data; constant columns are only centered
        """
        mean = data.mean(dim=0)
        std = data.std(dim=0)
        std[std == 0] = 1  # avoid division by zero for constant features
        return (data - mean) / std

    def GCF(self, adj, x, k=2):
        """
        Graph convolution filter: apply ``k`` rounds of D^{-1/2} (adj + I) D^{-1/2} to ``x``.
        parameters:
            adj: torch.Tensor, sparse n x n adjacency matrix. An identity is always
                 added here, so existing self-loops get their weight incremented.
            x: torch.Tensor, features with n rows
            k: int, number of hops
        return:
            torch.Tensor, filtered features
        """
        n = adj.shape[0]
        ind = torch.arange(n).repeat(2, 1)
        adj = adj + torch.sparse_coo_tensor(ind, torch.ones(n), (n, n))
        deg_inv_sqrt = torch.pow(torch.sparse.sum(adj, 1).to_dense(), -0.5)
        d_mat = torch.sparse_coo_tensor(ind, deg_inv_sqrt, (n, n))
        norm_adj = torch.sparse.mm(torch.sparse.mm(d_mat, adj), d_mat)
        for _ in range(k):
            x = torch.sparse.mm(norm_adj, x)
        return x

    def properties(self):
        """Return ``(k, n_split, n_classes, dim, n)``."""
        return self.k, self.n_split, self.n_classes, self.dim, self.n

    def split_batch(self):
        """
        Assign each split node to a batch and store the assignment in ``batch_labels``.
        Only ``split_method == 'kmeans'`` (KMeans with ``k`` clusters on ``split_feat``)
        is implemented; any other method leaves ``batch_labels`` unset.
        """
        if self.split_method == 'kmeans':
            kmeans = KMeans(n_clusters=self.k.item(), n_init=10)
            self.batch_labels = kmeans.fit_predict(self.split_feat.numpy())

    def getitem(self, idx):
        """
        Return node features, labels and the sub-adjacency matrix for the given
        positions ``idx`` (indices into the split, not global node ids).
        """
        idx_raw = self.split_idx[idx]
        feat = self.split_feat[idx]
        label = self.split_label[idx]
        return (feat, label, self._induced_subgraph(idx_raw))

    def get_batch(self, i):
        """Return ``getitem`` for all split nodes assigned to batch ``i`` by ``split_batch``."""
        idx = torch.where(torch.tensor(self.batch_labels) == i)[0]
        return self.getitem(idx)
class OgbDataLoader(nn.Module):
    """Placeholder loader for OGB datasets.

    Only the underlying ``nn.Module`` is initialized; every constructor
    argument is currently accepted and ignored.
    """

    def __init__(self, dataset_name='ogbn-arxiv', split='train', batch_size=5000, split_method='kmeans'):
        super().__init__()
class DataGraphSAINT:
    """Dataset stored in the GraphSAINT raw format.

    Expects ``adj_full.npz``, ``role.json``, ``feats.npy`` and ``class_map.json``
    under ``<root>/<dataset>/raw/`` (``-`` in ``dataset`` is replaced by ``_``).
    Missing files are fetched from a Google Drive folder via ``gdown``; for
    ogbn-arxiv they can also be regenerated from the OGB release.

    After construction the instance holds ``adj_full`` (scipy sparse matrix),
    ``num_nodes``, ``idx_train``/``idx_val``/``idx_test`` (lists from
    ``role.json``), ``train_mask``/``val_mask``/``test_mask``, ``feat_full``
    (numpy array), ``labels_full`` and ``nclass``. ``dataset[i]`` returns the
    instance itself.
    """

    _REQUIRED_FILES = ('adj_full.npz', 'role.json', 'feats.npy', 'class_map.json')

    def __init__(self, root, dataset, **kwargs):
        dataset = dataset.replace('-', '_')
        dataset_str = root + '/' + dataset + '/raw/'
        os.makedirs(dataset_str, exist_ok=True)
        if self._missing_files(dataset_str):
            self._download_from_drive(dataset_str)
        if self._missing_files(dataset_str) and dataset == 'ogbn_arxiv':
            print('Falling back to OGB loader for ogbn-arxiv.')
            self._prepare_ogbn_arxiv_from_ogb(dataset_str, root)
        missing_files = self._missing_files(dataset_str)
        if missing_files:
            raise RuntimeError(f'Unable to prepare dataset files: {missing_files}')
        # Only ogbn-arxiv's adjacency is symmetrized/binarized; other datasets are used as stored.
        self._load_raw(dataset_str, symmetrize=(dataset == 'ogbn_arxiv'))

    def _missing_files(self, dataset_str):
        """Return the required raw file names not present in ``dataset_str``."""
        return [fname for fname in self._REQUIRED_FILES if not osp.exists(osp.join(dataset_str, fname))]

    def _download_from_drive(self, dataset_str):
        """Best-effort download of the raw files into ``dataset_str``.

        Any failure (including ``gdown`` not being installed) is printed and
        swallowed; the caller re-checks which files are still missing.
        """
        print('Downloading dataset')
        url = 'https://drive.google.com/drive/folders/1VDobXR5KqKoov6WhYXFMwH4rN0FMnVOa'  # Change this to your actual file ID
        try:
            import gdown
            downloaded_items = gdown.download_folder(url=url, output=dataset_str, quiet=False)
        except Exception as exc:
            print(f'gdown download failed: {exc}')
            return
        if not downloaded_items:
            return
        if isinstance(downloaded_items, str):
            downloaded_items = [downloaded_items]
        # Flatten whatever was downloaded directly into dataset_str.
        for item_path in downloaded_items:
            if not isinstance(item_path, str):
                continue
            if osp.isdir(item_path):
                for filename in os.listdir(item_path):
                    src = osp.join(item_path, filename)
                    dst = osp.join(dataset_str, filename)
                    if osp.abspath(src) != osp.abspath(dst):
                        shutil.move(src, dst)
                if osp.abspath(item_path) != osp.abspath(dataset_str):
                    shutil.rmtree(item_path)
            elif osp.isfile(item_path):
                dst = osp.join(dataset_str, osp.basename(item_path))
                if osp.abspath(item_path) != osp.abspath(dst):
                    shutil.move(item_path, dst)

    def _load_raw(self, dataset_str, symmetrize):
        """Read the raw files from ``dataset_str`` into instance attributes."""
        adj_full = sp.load_npz(osp.join(dataset_str, 'adj_full.npz'))
        if symmetrize:
            adj_full = adj_full + adj_full.T
            adj_full[adj_full > 1] = 1
        self.adj_full = adj_full
        self.num_nodes = adj_full.shape[0]
        with open(osp.join(dataset_str, 'role.json'), 'r') as f:
            role = json.load(f)
        self.idx_train = role['tr']
        self.idx_test = role['te']
        self.idx_val = role['va']
        self.train_mask = index_to_mask(self.idx_train, self.num_nodes)
        self.test_mask = index_to_mask(self.idx_test, self.num_nodes)
        self.val_mask = index_to_mask(self.idx_val, self.num_nodes)
        # Features are loaded unnormalized; TransAndInd standardizes them later.
        self.feat_full = np.load(osp.join(dataset_str, 'feats.npy'))
        with open(osp.join(dataset_str, 'class_map.json'), 'r') as f:
            class_map = json.load(f)
        self.labels_full = to_tensor(label=self.process_labels(class_map))

    def _prepare_ogbn_arxiv_from_ogb(self, dataset_str, root):
        """Regenerate the GraphSAINT raw files for ogbn-arxiv from the OGB release.

        The OGB copy is cached under ``<root>/ogb_cache``. The adjacency is
        symmetrized and binarized before saving; splits come from
        ``get_idx_split()``.
        """
        from ogb.nodeproppred import PygNodePropPredDataset
        if torch_serialization is not None:
            # Allow-list PyG classes for torch.load when the installed torch supports it.
            add_safe_globals = getattr(torch_serialization, 'add_safe_globals', None)
            if callable(add_safe_globals):
                safe_classes = [cls for cls in (DataEdgeAttr, DataTensorAttr, GlobalStorage) if cls is not None]
                if safe_classes:
                    add_safe_globals(safe_classes)
        ogb_root = osp.join(root, 'ogb_cache')
        dataset = PygNodePropPredDataset(name='ogbn-arxiv', root=ogb_root)
        data = dataset[0]
        split_idx = dataset.get_idx_split()
        features = data.x.cpu().numpy().astype(np.float32)
        labels = data.y.view(-1).cpu().numpy().astype(np.int64)
        edge_index = data.edge_index.cpu().numpy()
        num_nodes = features.shape[0]
        values = np.ones(edge_index.shape[1], dtype=np.float32)
        adj = sp.coo_matrix((values, (edge_index[0], edge_index[1])), shape=(num_nodes, num_nodes))
        adj = adj.tocsr()
        adj = adj + adj.T
        adj[adj > 1] = 1
        sp.save_npz(osp.join(dataset_str, 'adj_full.npz'), adj)
        np.save(osp.join(dataset_str, 'feats.npy'), features)
        role = {
            'tr': split_idx['train'].cpu().tolist(),
            'va': split_idx['valid'].cpu().tolist(),
            'te': split_idx['test'].cpu().tolist(),
        }
        with open(osp.join(dataset_str, 'role.json'), 'w') as f:
            json.dump(role, f)
        class_map = {str(idx): int(label) for idx, label in enumerate(labels)}
        with open(osp.join(dataset_str, 'class_map.json'), 'w') as f:
            json.dump(class_map, f)

    def process_labels(self, class_map):
        """
        Convert a GraphSAINT class map into a label array and set ``self.nclass``.

        ``class_map`` maps node-id strings to either an integer class (single
        label; shifted so the minimum label is 0) or a list indicator vector
        (multi-label; returned as a float matrix). Requires ``self.num_nodes``.

        Raises ValueError if ``class_map`` is empty.
        """
        if not class_map:
            raise ValueError('class_map is empty')
        num_vertices = self.num_nodes
        first_value = next(iter(class_map.values()))
        if isinstance(first_value, list):
            num_classes = len(first_value)
            self.nclass = num_classes
            class_arr = np.zeros((num_vertices, num_classes))
            for k, v in class_map.items():
                class_arr[int(k)] = v
        else:
            class_arr = np.zeros(num_vertices, dtype=np.int64)
            for k, v in class_map.items():
                class_arr[int(k)] = v
            class_arr = class_arr - class_arr.min()
            self.nclass = int(class_arr.max()) + 1
        return class_arr

    def get(self, idx):
        """Return the dataset itself (it holds a single graph)."""
        return self

    def __getitem__(self, idx):
        return self.get(idx)