import math
import torch
import torch.nn as nn
class StructureBasedNeuralTangentKernel(nn.Module):
    """Structure-based neural tangent kernel between the nodes of two graphs.

    Each of the ``K`` rounds aggregates the current kernel over graph
    neighbourhoods (via the Kronecker product of the two adjacency patterns)
    and then applies ``L`` nonlinearity updates.

    Args:
        K: Number of aggregation rounds.
        L: Number of nonlinearity updates applied after every aggregation.
        scale: ``'add'`` leaves the aggregated sums unscaled; any other value
            divides each aggregated entry by the number of neighbour pairs
            that contributed to it.
    """

    def __init__(self, K=2, L=2, scale='add'):
        super().__init__()
        self.K = K
        self.L = L
        self.scale = scale

    def sparse_kron(self, A, B):
        """Return the Kronecker product of the sparsity patterns of A and B.

        Args:
            A: Sparse COO tensor of shape ``(m, n)``.
            B: Sparse COO tensor of shape ``(p, q)``, on the same device as A.

        Returns:
            An (uncoalesced) sparse COO tensor of shape ``(m * p, n * q)``
            holding ``1.0`` at every position where both factors have a stored
            entry. The stored *values* of A and B are ignored, so this equals
            the true Kronecker product only for 0/1 matrices.
        """
        m, n = A.shape
        p, q = B.shape

        # Coalesce first so the entry counts below match the index tensors
        # even when an input carries duplicate coordinates.
        idx_a = A.coalesce().indices()
        idx_b = B.coalesce().indices()
        n_a = idx_a.shape[1]
        n_b = idx_b.shape[1]

        # `indices()` aliases the tensor's own storage, so everything here is
        # computed out-of-place: writing into idx_a would corrupt A (and B as
        # well when the caller passes the same tensor twice).
        rows = (idx_a[0] * p).unsqueeze(0) + idx_b[0].unsqueeze(1)
        cols = (idx_a[1] * q).unsqueeze(0) + idx_b[1].unsqueeze(1)

        # Both grids are (n_b, n_a); flattening keeps row/col pairs aligned.
        new_ind = torch.stack((rows.reshape(-1), cols.reshape(-1)))
        values = torch.ones(n_a * n_b, device=A.device)
        return torch.sparse_coo_tensor(new_ind, values, (m * p, n * q))

    def aggr(self, S, aggr_optor, n1, n2, scale_mat):
        """Aggregate the kernel ``S`` over neighbour pairs.

        Args:
            S: Dense ``(n1, n2)`` kernel matrix.
            aggr_optor: Sparse ``(n1 * n2, n1 * n2)`` aggregation operator.
            n1: Number of rows of ``S``.
            n2: Number of columns of ``S``.
            scale_mat: Scalar or ``(n1, n2)`` tensor multiplied element-wise
                into the aggregated result.

        Returns:
            The aggregated, rescaled ``(n1, n2)`` kernel.
        """
        flat = S.reshape(-1, 1)
        return torch.sparse.mm(aggr_optor, flat).reshape(n1, n2) * scale_mat

    @staticmethod
    def _arccos_kernel(corr):
        """Map correlations in [-1, 1] through the degree-1 arc-cosine form."""
        angle_term = math.pi - torch.arccos(corr)
        return (corr * angle_term + torch.sqrt(1 - corr * corr)) / math.pi

    def update_sigma(self, S, diag1, diag2):
        """Apply one nonlinearity update to a cross-graph kernel.

        Args:
            S: Dense ``(n1, n2)`` kernel matrix.
            diag1: Length-``n1`` normalisers for the rows.
            diag2: Length-``n2`` normalisers for the columns.

        Returns:
            Tuple ``(S, degree_sigma)``: the updated kernel rescaled back by
            the normalisers, and the matching derivative factor.
        """
        corr = S / diag1[:, None] / diag2[None, :]
        # Keep strictly inside (-1, 1) so arccos/sqrt stay finite.
        corr = torch.clip(corr, -0.9999, 0.9999)
        corr = self._arccos_kernel(corr)
        # NOTE(review): the derivative factor is evaluated on the already
        # transformed correlation, not on the clipped input. Preserved as in
        # the original implementation - confirm this ordering is intended.
        degree_sigma = (math.pi - torch.arccos(corr)) / math.pi
        return corr * diag1[:, None] * diag2[None, :], degree_sigma

    def update_diag(self, S):
        """Apply one nonlinearity update to a within-graph kernel.

        Args:
            S: Dense square kernel matrix.

        Returns:
            Tuple ``(S, diag)``: the updated kernel and the square roots of
            the *input* kernel's diagonal used as normalisers.
        """
        diag = torch.sqrt(torch.diag(S))
        corr = S / diag[:, None] / diag[None, :]
        corr = torch.clip(corr, -0.9999, 0.9999)
        corr = self._arccos_kernel(corr)
        return corr * diag[:, None] * diag[None, :], diag

    def _aggregation_operator(self, E1, E2, n1, n2):
        """Build the Kronecker aggregation operator and its scaling matrix."""
        aggr_optor = self.sparse_kron(E1, E2)
        if self.scale == 'add':
            return aggr_optor, 1.
        # Row sums count contributing neighbour pairs; a node without any
        # stored entry yields a zero count and therefore inf here.
        pair_counts = torch.sparse.sum(aggr_optor, 1).to_dense()
        return aggr_optor, (1. / pair_counts).reshape(n1, n2)

    def diag(self, g, E):
        """Compute the per-round normalisers for a single graph.

        Args:
            g: Dense node-feature matrix with ``E.shape[0]`` rows.
            E: Sparse square adjacency pattern of the graph.

        Returns:
            List of ``K`` vectors, one normaliser vector per aggregation round.
        """
        n = E.shape[0]
        aggr_optor, scale_mat = self._aggregation_operator(E, E, n, n)

        diag_list = []
        sigma = torch.matmul(g, g.t())
        for _ in range(self.K):
            sigma = self.aggr(sigma, aggr_optor, n, n, scale_mat)
            sigma, diag = self.update_diag(sigma)
            diag_list.append(diag)
        return diag_list

    def nodes_gram(self, g1, g2, E1, E2):
        """Compute the kernel matrix between the nodes of two graphs.

        Args:
            g1: Dense node-feature matrix of the first graph.
            g2: Dense node-feature matrix of the second graph.
            E1: Sparse square adjacency pattern of the first graph.
            E2: Sparse square adjacency pattern of the second graph.

        Returns:
            Dense ``(len(g1), len(g2))`` kernel matrix.
        """
        n1, n2 = len(g1), len(g2)
        aggr_optor, scale_mat = self._aggregation_operator(E1, E2, n1, n2)

        sigma = torch.matmul(g1, g2.t())
        theta = sigma
        diag_list1 = self.diag(g1, E1)
        diag_list2 = self.diag(g2, E2)

        for k in range(self.K):
            sigma = self.aggr(sigma, aggr_optor, n1, n2, scale_mat)
            theta = self.aggr(theta, aggr_optor, n1, n2, scale_mat)
            for _ in range(self.L):
                sigma, degree_sigma = self.update_sigma(
                    sigma, diag_list1[k], diag_list2[k]
                )
                theta = theta * degree_sigma + sigma
        return theta