File size: 1,898 Bytes
9a67fbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
import gpytorch
from gauche.kernels.fingerprint_kernels.tanimoto_kernel import TanimotoKernel
from gpytorch.kernels import ScaleKernel
from gpytorch.means import ConstantMean
from gpytorch.distributions import MultivariateNormal


class ExactGPModel(gpytorch.models.ExactGP):
    """Exact GP whose prior is a constant mean plus a scaled Tanimoto kernel.

    The Tanimoto kernel presumably expects fingerprint-style feature vectors
    (it lives in gauche's ``fingerprint_kernels`` package); confirm against
    the callers' featurisation.

    Args:
        train_x: Training inputs handed straight to ``ExactGP``.
        train_y: Training targets handed straight to ``ExactGP``.
        likelihood: Likelihood module handed straight to ``ExactGP``.
    """

    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        # Learnable constant prior mean.
        self.mean_module = ConstantMean()
        # Tanimoto similarity wrapped in a learnable output scale.
        self.covar_module = ScaleKernel(base_kernel=TanimotoKernel())

    def forward(self, x):
        """Return the prior ``MultivariateNormal`` evaluated at inputs ``x``."""
        return MultivariateNormal(self.mean_module(x), self.covar_module(x))  # type: ignore


class TanimotoGP:
    """Fit/predict wrapper around an exact GP with a scaled Tanimoto kernel.

    ``fit`` converts the data to float32 tensors and builds a fresh
    ``ExactGPModel``. It then tunes the model and likelihood hyperparameters by
    minimising the negative exact marginal log-likelihood with Adam.
    ``predict`` returns the predictive mean and variance as NumPy arrays.
    """

    def __init__(self):
        # Replaced on every call to ``fit``; created here so the attribute
        # exists before the first fit.
        self.likelihood = gpytorch.likelihoods.GaussianLikelihood()
        self.model = None
        self.train_x = None
        self.train_y = None

    @staticmethod
    def _to_float_tensor(data):
        """Return ``data`` as a new float32 tensor that never aliases the input."""
        if isinstance(data, torch.Tensor):
            # torch.tensor() on an existing tensor warns; detach plus an
            # explicit copy keeps the same "always a fresh copy" semantics.
            return data.detach().to(dtype=torch.float, copy=True)
        return torch.tensor(data, dtype=torch.float)

    @classmethod
    def _as_targets(cls, y):
        """Convert ``y`` to float32 targets, flattening an (n, 1) column to (n,)."""
        targets = cls._to_float_tensor(y)
        # ExactGP needs one scalar target per training point; an (n, 1) column
        # would otherwise be broadcast against the n-dimensional predictive MVN.
        if targets.ndim == 2 and targets.shape[1] == 1:
            targets = targets[:, 0]
        return targets

    def fit(self, X, y, *, n_iter=50, lr=0.1):
        """Fit the GP hyperparameters to ``(X, y)``.

        Args:
            X: Training inputs, anything ``torch.tensor`` accepts (or a tensor);
               the first dimension indexes training points.
            y: Training targets, shape (n,) or (n, 1).
            n_iter: Number of Adam steps on the negative marginal log-likelihood.
            lr: Adam learning rate.

        Raises:
            ValueError: If ``X`` and ``y`` disagree on the number of points.
        """
        train_x = self._to_float_tensor(X)
        train_y = self._as_targets(y)
        if train_x.shape[:1] != train_y.shape[:1]:
            raise ValueError(
                f"X has {tuple(train_x.shape)} but y has {tuple(train_y.shape)}; "
                "they must have the same number of rows."
            )
        self.train_x = train_x
        self.train_y = train_y

        # Fresh likelihood per fit: otherwise the noise learned in a previous
        # fit would warm-start this one while mean/kernel params are reset.
        self.likelihood = gpytorch.likelihoods.GaussianLikelihood()
        self.model = ExactGPModel(self.train_x, self.train_y, self.likelihood)
        self.model.train()
        self.likelihood.train()

        # ExactGP keeps the likelihood as a submodule, so model.parameters()
        # also covers the likelihood noise (see GPyTorch's regression tutorial).
        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        mll = gpytorch.mlls.ExactMarginalLogLikelihood(self.likelihood, self.model)

        for _ in range(n_iter):
            optimizer.zero_grad()
            output = self.model(self.train_x)
            loss = -mll(output, self.train_y)  # type: ignore
            loss.backward()  # type: ignore
            optimizer.step()

    def predict(self, X):
        """Return ``(mean, variance)`` NumPy arrays of the predictive distribution at ``X``.

        The model output is passed through the likelihood, so the variance
        includes the learned observation noise.

        Raises:
            RuntimeError: If called before ``fit``.
        """
        if self.model is None:
            raise RuntimeError("TanimotoGP.predict() called before fit().")

        self.model.eval()
        self.likelihood.eval()

        test_x = self._to_float_tensor(X)
        with torch.no_grad(), gpytorch.settings.fast_pred_var():
            preds = self.likelihood(self.model(test_x))
            return preds.mean.numpy(), preds.variance.numpy()