Source code for graphslim.models.krr

import torch
import torch.nn as nn


class KernelRidgeRegression(nn.Module):
    """Closed-form kernel ridge regression evaluated on a target set.

    The module fits ridge-regression coefficients on a "source" set through
    a user-supplied kernel and immediately scores a "target" set with them.

    Args:
        kernel: Callable invoked as ``kernel(G_a, G_b, E_a, E_b)``; its
            result is used as a kernel (Gram) matrix between the two sets.
        ridge: Multiplier for the regularisation term. The effective
            regulariser is ``ridge * trace(K_ss) / n`` times the identity.
    """

    def __init__(self, kernel, ridge):
        super().__init__()
        self.kernel = kernel
        self.ridge = ridge

    def forward(self, G_t, G_s, y_t, y_s, E_t, E_s):
        """Fit on the source set and predict/score the target set.

        Args:
            G_t: Target inputs, passed through to ``kernel`` unchanged.
            G_s: Source inputs; ``len(G_s)`` is the number of source samples.
            y_t: Target labels; the class is taken as ``argmax`` over dim 1
                (presumably one-hot or score rows -- confirm with callers).
            y_s: Source labels, the right-hand side of the linear solve.
            E_t: Extra target argument forwarded to ``kernel``.
            E_s: Extra source argument forwarded to ``kernel``.

        Returns:
            A ``(pred, acc)`` tuple: ``pred`` holds the softmax (over dim 1)
            of the regression output for the target set, and ``acc`` is a
            Python float with the fraction of rows whose arg-max matches
            that of ``y_t``.

        Raises:
            ZeroDivisionError: If ``y_t`` is empty.
        """
        K_ss = self.kernel(G_s, G_s, E_s, E_s)
        K_ts = self.kernel(G_t, G_s, E_t, E_s)

        # A plain int is enough for both the identity size and the division;
        # wrapping it in a device tensor only forces a device-to-host read.
        n = len(G_s)

        # Scale the ridge by the mean diagonal of K_ss so that `ridge` is
        # relative to the magnitude of the kernel.
        scale = self.ridge * torch.trace(K_ss) / n

        # Build the identity in K_ss's dtype/device so the regulariser is
        # not rounded through float32 when the kernel matrix is float64.
        identity = torch.eye(n, dtype=K_ss.dtype, device=K_ss.device)
        regularizer = scale * identity

        # Solve (K_ss + lambda * I) b = y_s instead of forming an inverse.
        b = torch.linalg.solve(K_ss + regularizer, y_s)

        pred = torch.matmul(K_ts, b)
        pred = nn.functional.softmax(pred, dim=1)

        # argmax yields integer class indices that compare directly.
        correct = torch.eq(pred.argmax(1), y_t.argmax(1)).sum().item()
        acc = correct / len(y_t)
        return pred, acc