File size: 1,898 Bytes
9a67fbe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
import torch
import gpytorch
from gauche.kernels.fingerprint_kernels.tanimoto_kernel import TanimotoKernel
from gpytorch.kernels import ScaleKernel
from gpytorch.means import ConstantMean
from gpytorch.distributions import MultivariateNormal
class ExactGPModel(gpytorch.models.ExactGP):
    """Exact GP model using a constant mean and a scaled Tanimoto kernel.

    The prior mean is a ``ConstantMean`` and the prior covariance is a
    ``TanimotoKernel`` wrapped in a ``ScaleKernel``.
    """

    def __init__(self, train_x, train_y, likelihood):
        """Set up the mean and covariance modules.

        Args:
            train_x: Training inputs, forwarded unchanged to ``ExactGP``.
            train_y: Training targets, forwarded unchanged to ``ExactGP``.
            likelihood: Likelihood object, forwarded unchanged to ``ExactGP``.
        """
        super().__init__(train_x, train_y, likelihood)
        # Mean is registered before the kernel, as in the original layout.
        self.mean_module = ConstantMean()
        base_kernel = TanimotoKernel()
        self.covar_module = ScaleKernel(base_kernel)

    def forward(self, x):
        """Return the prior ``MultivariateNormal`` evaluated at ``x``."""
        prior_mean = self.mean_module(x)
        prior_covar = self.covar_module(x)
        return MultivariateNormal(prior_mean, prior_covar)  # type: ignore
class TanimotoGP:
    """Minimal fit/predict wrapper around ``ExactGPModel``.

    ``fit`` converts the data to float tensors, builds an ``ExactGPModel``
    and optimises it with Adam against the exact marginal log-likelihood.
    ``predict`` returns the predictive mean and variance as NumPy arrays.

    The ``GaussianLikelihood`` is created once in ``__init__`` and the same
    object is reused by every call to ``fit``.
    """

    def __init__(self):
        """Create an unfitted wrapper with a fresh Gaussian likelihood."""
        self.likelihood = gpytorch.likelihoods.GaussianLikelihood()
        # The attributes below are populated by fit(); ``model is None``
        # means the wrapper has not been fitted yet.
        self.model = None
        self.train_x = None
        self.train_y = None

    def fit(self, X, y, n_iter=50, lr=0.1):
        """Fit the GP to the training data.

        Args:
            X: Training inputs; anything ``torch.tensor`` accepts. Converted
                to a float tensor.
            y: Training targets; anything ``torch.tensor`` accepts. Converted
                to a float tensor. A column vector of shape ``(n, 1)`` is
                flattened to ``(n,)``.
            n_iter: Number of Adam steps to take (default 50).
            lr: Adam learning rate (default 0.1).
        """
        self.train_x = torch.tensor(X, dtype=torch.float)
        train_y = torch.tensor(y, dtype=torch.float)
        # A (n, 1) target would broadcast against the (n,) GP output when the
        # marginal log-likelihood is evaluated; flatten it to (n,) up front.
        if train_y.dim() == 2 and train_y.size(-1) == 1:
            train_y = train_y.squeeze(-1)
        self.train_y = train_y

        self.model = ExactGPModel(self.train_x, self.train_y, self.likelihood)
        self.model.train()
        self.likelihood.train()

        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        mll = gpytorch.mlls.ExactMarginalLogLikelihood(self.likelihood, self.model)
        for _ in range(n_iter):
            optimizer.zero_grad()
            output = self.model(self.train_x)
            # Minimise the negative marginal log-likelihood.
            loss = -1 * mll(output, self.train_y)  # type: ignore
            loss.backward()  # type: ignore
            optimizer.step()

    def predict(self, X):
        """Predict at new inputs.

        Args:
            X: Test inputs; anything ``torch.tensor`` accepts. Converted to a
                float tensor.

        Returns:
            A ``(mean, variance)`` tuple of NumPy arrays, taken from the
            distribution obtained by passing the model output through the
            likelihood.

        Raises:
            RuntimeError: If called before ``fit``.
        """
        if self.model is None:
            # Fail with an actionable message instead of an AttributeError
            # on ``None.eval()``.
            raise RuntimeError(
                "TanimotoGP.predict() called before fit(); call fit(X, y) first."
            )
        self.model.eval()
        self.likelihood.eval()
        test_x = torch.tensor(X, dtype=torch.float)
        with torch.no_grad(), gpytorch.settings.fast_pred_var():
            preds = self.likelihood(self.model(test_x))  # type: ignore
        return preds.mean.numpy(), preds.variance.numpy()
|