from itertools import product
import numpy as np
import scipy
import torch
import torch.nn.functional as F
import torch_sparse
from torch import nn
# from graph import mx_inv, mx_inv_sqrt, mx_tr
def mx_inv(mx, device, eps=0.009):
    """Return the SVD-based (pseudo-)inverse of a matrix.

    Args:
        mx: matrix to invert -- a dense ``torch.Tensor``, any SciPy sparse
            matrix/array (densified as float32), a ``torch_sparse.SparseTensor``
            (densified), or anything ``torch.as_tensor`` accepts.
        device: device the dense matrix is moved to before the SVD; ``None``
            leaves it where it is.
        eps: singularity threshold. If the smallest singular value is below
            ``eps``, its reciprocal (and that of any singular value tied with
            it) is set to 0 instead of being inverted. Defaults to 0.009.

    Returns:
        ``V diag(1 / s) U^T`` for ``mx = U diag(s) V^T``: exactly ``mx^{-1}``
        for a square ``mx`` whose singular values are all >= ``eps``, a
        truncated pseudo-inverse otherwise.

    NOTE(review): only the single smallest singular value is dropped; values
    strictly between it and ``eps`` are still inverted and can be huge when
    the matrix has several near-zero singular values -- confirm intended.
    """
    import scipy.sparse

    # Dense torch tensors are checked first so the common case never touches
    # the optional sparse libraries.
    if isinstance(mx, torch.Tensor):
        dense = mx
    elif scipy.sparse.issparse(mx):
        # Covers csr/csc/coo/... matrices and arrays, not only csr_matrix.
        dense = torch.as_tensor(mx.toarray(), dtype=torch.float32)
    elif isinstance(mx, torch_sparse.tensor.SparseTensor):
        dense = mx.to_dense()
    else:
        dense = torch.as_tensor(mx)
    if device is not None:
        dense = dense.to(device)

    U, S, Vh = torch.linalg.svd(dense, full_matrices=False)
    S_min = torch.min(S)
    if S_min < eps:
        S_inv = torch.zeros_like(S)
        keep = S > S_min
        S_inv[keep] = 1 / S[keep]
    else:
        S_inv = 1 / S
    # mx = U diag(S) Vh  =>  mx^+ = Vh^T diag(1/S) U^T (real-valued input).
    return Vh.t() @ S_inv.diag() @ U.t()
def mx_inv_sqrt(mx):
    """Return ``(inv_sqrt, inv)`` for ``mx``, both rebuilt from one SVD.

    With ``mx = U diag(s) V^T`` and ``r = 1 / s`` this returns
    ``U diag(sqrt(r)) V^T`` and ``U diag(r) V^T``. When the smallest singular
    value is below 0.009, its reciprocal (and that of any value tied with it)
    is set to 0 instead of being computed.

    The factor order ``U ... V^T`` gives the inverse of ``mx^T``, so ``inv`` is
    the inverse of ``mx`` only for symmetric input, and ``inv_sqrt`` is a true
    inverse square root only for symmetric positive (semi)definite input.
    """
    # singular values need to be distinct for backprop
    left, sing, right = torch.svd(mx)
    smallest = torch.min(sing)
    cutoff = 0.009
    if smallest < cutoff:
        recip = torch.zeros_like(sing)
        keep = sing > smallest
        recip[keep] = 1 / sing[keep]
    else:
        recip = 1 / sing
    right_t = right.t()
    inv_sqrt = left @ recip.sqrt().diag() @ right_t
    inv = left @ recip.diag() @ right_t
    return inv_sqrt, inv
def mx_tr(mx):
    """Return the trace of a 2-D tensor (sum of its main diagonal) as a 0-d tensor."""
    main_diagonal = torch.diag(mx)
    return torch.sum(main_diagonal)
def get_mgrid(sidelen, dim=2):
    """Return the points of a regular grid spanning ``[-1, 1]`` on every axis.

    Args:
        sidelen: samples per axis -- an ``int`` (same for every axis) or a
            sequence with at least ``dim`` entries (extra entries are ignored).
        dim: number of axes; any positive integer (default 2).

    Returns:
        A float32 tensor of shape ``(prod(sidelen[:dim]), dim)`` listing the
        grid points in row-major order; e.g. for ``dim=2`` row
        ``i * sidelen[1] + j`` holds the coordinate of cell ``(i, j)``.
        An axis with a single sample maps to ``-1`` instead of ``NaN``.

    Raises:
        NotImplementedError: if ``dim < 1``.
    """
    if dim < 1:
        raise NotImplementedError("Not implemented for dim=%d" % dim)
    if isinstance(sidelen, int):
        sidelen = dim * (sidelen,)
    sizes = [sidelen[axis] for axis in range(dim)]
    # mgrid over a tuple of slices yields shape (dim, *sizes); stacking on the
    # last axis gives (*sizes, dim) integer coordinates.
    grid = np.mgrid[tuple(slice(None, n) for n in sizes)]
    pixel_coords = np.stack(grid, axis=-1).astype(np.float32)
    for axis, n in enumerate(sizes):
        # Guard single-sample axes: dividing by n - 1 == 0 would give 0/0 = NaN.
        pixel_coords[..., axis] /= max(n - 1, 1)
    # Map [0, 1] -> [-1, 1].
    pixel_coords -= 0.5
    pixel_coords *= 2.0
    return torch.Tensor(pixel_coords).view(-1, dim).to(torch.float32)
class Sine(nn.Module):
    """Element-wise sine activation computing ``sin(w0 * input)``.

    Args:
        w0: frequency multiplier applied to the input before ``torch.sin``.
            Defaults to 30, the value this module always used.
    """

    def __init__(self, w0=30):
        super().__init__()
        self.w0 = w0

    def forward(self, input):
        """Return ``sin(w0 * input)`` element-wise."""
        return torch.sin(self.w0 * input)
class EdgeBlock(nn.Module):
    """Feature block applying ``Linear -> BatchNorm1d -> ReLU`` in sequence.

    Args:
        in_: number of input features.
        out_: number of output features.
        dtype: parameter dtype of the linear and batch-norm layers.
    """

    def __init__(self, in_, out_, dtype=torch.float32) -> None:
        super().__init__()
        layers = [
            nn.Linear(in_, out_, dtype=dtype),
            nn.BatchNorm1d(out_, dtype=dtype),
            nn.ReLU(),
        ]
        # Wrapped in ``self.net`` so parameter names stay ``net.0.*`` / ``net.1.*``.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` (2-D ``(batch, in_)`` in this file's usage) through the block."""
        return self.net(x)
class GraphonLearner(nn.Module):
    """Learn a dense ``nnodes x nnodes`` adjacency matrix from node features.

    Every ordered node pair ``(i, j)`` is processed by two towers:

    * ``net0`` sees the concatenated endpoint features ``[c_i, c_j]``;
    * ``net1`` sees the pair's 2-D grid coordinate produced by ``get_mgrid``.

    Hidden activations of ``net1`` are multiplied by those of ``net0`` and
    the two scalar outputs are mixed with weight ``ep_ratio``. The result is
    symmetrised, passed through a sigmoid and its diagonal is zeroed.

    In training mode ``forward`` also returns ``opt_loss``, which compares the
    generated matrix with the cached inverse of a reference matrix ``Lx``
    through the learnable transport plan ``P`` of shape ``(mx_size, nnodes)``.

    Args:
        node_feature: dimensionality of each node's feature vector.
        nfeat: hidden width of both towers.
        nnodes: number of nodes of the generated graph (``step_size``).
        device: device on which the grid coordinates are placed.
        args: hyper-parameter container providing ``ep_ratio``,
            ``sinkhorn_iter`` and ``mx_size``, either as attributes
            (e.g. an ``argparse.Namespace``) or as ``dict`` keys.
        num_hidden_layers: stored on the instance; not read by this class.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        node_feature,
        nfeat=256,
        nnodes=50,
        device="cuda",
        args=None,
        num_hidden_layers=3,
        **kwargs
    ):
        super().__init__()
        if args is None:
            # No shared mutable default; missing fields are reported by _read_arg.
            args = {}
        self.num_hidden_layers = num_hidden_layers
        self.step_size = nnodes
        self.ep_ratio = self._read_arg(args, "ep_ratio")
        self.sinkhorn_iter = self._read_arg(args, "sinkhorn_iter")
        self.mx_size = self._read_arg(args, "mx_size")
        # (2, nnodes**2) array of every ordered pair (i, j), row-major.
        self.edge_index = np.array(
            list(product(range(self.step_size), range(self.step_size)))
        ).T
        # Tower over concatenated endpoint features.
        self.net0 = nn.ModuleList(
            [
                EdgeBlock(node_feature * 2, nfeat),
                EdgeBlock(nfeat, nfeat),
                nn.Linear(nfeat, 1, dtype=torch.float32),
            ]
        )
        # Tower over 2-D grid coordinates.
        self.net1 = nn.ModuleList(
            [
                EdgeBlock(2, nfeat),
                EdgeBlock(nfeat, nfeat),
                nn.Linear(nfeat, 1, dtype=torch.float32),
            ]
        )
        # Transport plan (unnormalised); opt_loss Sinkhorn-normalises |P|.
        self.P = nn.Parameter(
            torch.empty(self.mx_size, self.step_size, dtype=torch.float32).uniform_(0, 1)
        )
        # Cache for mx_inv(Lx); filled by the first forward call that passes Lx.
        self.Lx_inv = None
        # Not used by forward; still a registered submodule (part of state_dict).
        self.output = nn.Linear(nfeat, 1)
        self.act = F.relu  # not read within this class
        self.device = device
        self.reset_parameters()

    @staticmethod
    def _read_arg(args, name):
        """Return hyper-parameter ``name`` from a ``dict`` or an attribute namespace.

        Raises:
            AttributeError: if ``name`` is missing (the same error type that
                plain attribute access on a namespace raises).
        """
        if isinstance(args, dict):
            if name not in args:
                raise AttributeError(f"args is missing required field {name!r}")
            return args[name]
        return getattr(args, name)

    def reset_parameters(self):
        """Re-initialise every ``nn.Linear`` and ``nn.BatchNorm1d`` submodule.

        ``self.P`` is a bare ``nn.Parameter``, not one of these modules, so it
        keeps its current values.
        """
        def _reset(module):
            if isinstance(module, (nn.Linear, nn.BatchNorm1d)):
                module.reset_parameters()

        self.apply(_reset)

    def forward(self, c, inference=False, Lx=None):
        """Generate the adjacency matrix and, when training, the transport loss.

        Args:
            c: node-feature matrix, expected shape ``(nnodes, node_feature)``.
            inference: if true, switch to eval mode and return only ``adj``;
                otherwise switch to train mode and also return the loss.
            Lx: reference matrix; ``mx_inv(Lx)`` is computed and cached on the
                first call that provides it, and later values are ignored.

        Returns:
            ``adj`` of shape ``(nnodes, nnodes)`` when ``inference`` is true,
            otherwise the tuple ``(adj, opt_loss)``.
        """
        # Side effect: the module's train/eval mode is set on every call.
        if inference:
            self.eval()
        else:
            self.train()
        # One grid coordinate per ordered node pair (input of net1's first block).
        x = get_mgrid(c.shape[0]).to(self.device)
        # Pair features [c_i, c_j] for every (i, j) in self.edge_index.
        c = torch.cat([c[self.edge_index[0]], c[self.edge_index[1]]], axis=1)
        last = len(self.net0) - 1
        for i, (edge_layer, grid_layer) in enumerate(zip(self.net0, self.net1)):
            c = edge_layer(c)
            x = grid_layer(x)
            if i != last:
                # Use node features to guide the graphon generating process.
                x = x * c
            else:
                x = (1 - self.ep_ratio) * x + self.ep_ratio * c
        adj = x.reshape(self.step_size, self.step_size)
        adj = (adj + adj.T) / 2
        adj = torch.sigmoid(adj)
        # Zero the diagonal (no self-loops).
        adj = adj - torch.diag(torch.diag(adj, 0))
        if inference:
            return adj
        if Lx is not None and self.Lx_inv is None:
            self.Lx_inv = mx_inv(Lx, c.device)
        return adj, self.opt_loss(adj)

    def opt_loss(self, adj):
        """Transport-plan discrepancy between ``adj`` and the cached ``Lx_inv``.

        With ``P`` the Sinkhorn-normalised ``|self.P|`` and
        ``(Ly_inv_rt, Ly_inv) = mx_inv_sqrt(adj)``, returns
        ``|nnodes * tr(Ly_inv) - 2 * sum(sqrt(eigvals(Ly_inv_rt P^T Lx_inv P Ly_inv_rt)))|``
        with eigenvalues clamped to at least ``2e-20`` before the square root.

        Raises:
            RuntimeError: if ``Lx_inv`` has not been set through ``forward``.
        """
        if self.Lx_inv is None:
            raise RuntimeError(
                "Lx_inv is not set: pass Lx to forward() before computing opt_loss"
            )
        Ly_inv_rt, Ly_inv = mx_inv_sqrt(adj)
        m = self.step_size
        P = self.P.abs()
        # Sinkhorn iterations: alternate row and column normalisation.
        for _ in range(self.sinkhorn_iter):
            P = P / P.sum(dim=1, keepdim=True)
            P = P / P.sum(dim=0, keepdim=True)
        # Use the normalised plan P, not the raw parameter self.P; otherwise the
        # Sinkhorn loop above has no effect on the loss. Only eigenvalues are
        # needed, so eigvalsh (cheaper, numerically stable gradient) suffices.
        eigvals = torch.linalg.eigvalsh(
            Ly_inv_rt @ P.t() @ self.Lx_inv @ P @ Ly_inv_rt
        )
        return torch.abs(
            mx_tr(Ly_inv) * m - 2 * torch.sqrt(eigvals.clamp(min=2e-20)).sum()
        )

    @torch.no_grad()
    def inference(self, c):
        """Return the generated adjacency matrix (eval mode, no gradients)."""
        return self.forward(c, inference=True)